Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD695
-rw-r--r--tensorflow/core/client/tensor_c_api.cc370
-rw-r--r--tensorflow/core/client/tensor_c_api_test.cc94
-rw-r--r--tensorflow/core/common_runtime/device.cc37
-rw-r--r--tensorflow/core/common_runtime/device.h128
-rw-r--r--tensorflow/core/common_runtime/device_factory.cc106
-rw-r--r--tensorflow/core/common_runtime/device_factory.h69
-rw-r--r--tensorflow/core/common_runtime/device_mgr.cc90
-rw-r--r--tensorflow/core/common_runtime/device_mgr.h55
-rw-r--r--tensorflow/core/common_runtime/device_set.cc68
-rw-r--r--tensorflow/core/common_runtime/device_set.h64
-rw-r--r--tensorflow/core/common_runtime/device_set_test.cc65
-rw-r--r--tensorflow/core/common_runtime/eigen_thread_pool.h22
-rw-r--r--tensorflow/core/common_runtime/executor.cc2118
-rw-r--r--tensorflow/core/common_runtime/executor.h209
-rw-r--r--tensorflow/core/common_runtime/function.cc1335
-rw-r--r--tensorflow/core/common_runtime/function.h100
-rw-r--r--tensorflow/core/common_runtime/gpu/dma_helper.h18
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc49
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h36
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc175
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc397
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h156
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc166
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc186
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h68
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc207
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc651
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h94
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_factory.cc52
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc132
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.h118
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc152
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_init.cc147
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_init.h19
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc371
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_region_allocator.h146
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc71
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util.cc97
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util.h30
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc137
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.cc345
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.h89
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc24
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator.cc269
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator.h202
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc203
-rw-r--r--tensorflow/core/common_runtime/gpu/process_state.cc220
-rw-r--r--tensorflow/core/common_runtime/gpu/process_state.h140
-rw-r--r--tensorflow/core/common_runtime/gpu/visitable_allocator.h30
-rw-r--r--tensorflow/core/common_runtime/gpu_device_context.h45
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.cc160
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.h52
-rw-r--r--tensorflow/core/common_runtime/local_device.cc51
-rw-r--r--tensorflow/core/common_runtime/local_device.h27
-rw-r--r--tensorflow/core/common_runtime/local_session.cc500
-rw-r--r--tensorflow/core/common_runtime/local_session.h109
-rw-r--r--tensorflow/core/common_runtime/local_session_test.cc314
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.cc170
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.h73
-rw-r--r--tensorflow/core/common_runtime/session.cc51
-rw-r--r--tensorflow/core/common_runtime/session_factory.cc41
-rw-r--r--tensorflow/core/common_runtime/session_factory.h25
-rw-r--r--tensorflow/core/common_runtime/session_options.cc9
-rw-r--r--tensorflow/core/common_runtime/session_test.cc17
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc559
-rw-r--r--tensorflow/core/common_runtime/simple_placer.h81
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc863
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc55
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.h31
-rw-r--r--tensorflow/core/common_runtime/threadpool_device_factory.cc31
-rw-r--r--tensorflow/core/example/example.proto95
-rw-r--r--tensorflow/core/example/feature.proto82
-rw-r--r--tensorflow/core/framework/allocation_description.proto15
-rw-r--r--tensorflow/core/framework/allocator.cc25
-rw-r--r--tensorflow/core/framework/allocator.h132
-rw-r--r--tensorflow/core/framework/allocator_test.cc61
-rw-r--r--tensorflow/core/framework/attr_value.proto57
-rw-r--r--tensorflow/core/framework/attr_value_util.cc382
-rw-r--r--tensorflow/core/framework/attr_value_util.h83
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc91
-rw-r--r--tensorflow/core/framework/bfloat16.cc22
-rw-r--r--tensorflow/core/framework/bfloat16.h58
-rw-r--r--tensorflow/core/framework/bfloat16_test.cc69
-rw-r--r--tensorflow/core/framework/cancellation.cc79
-rw-r--r--tensorflow/core/framework/cancellation.h121
-rw-r--r--tensorflow/core/framework/cancellation_test.cc102
-rw-r--r--tensorflow/core/framework/config.proto61
-rw-r--r--tensorflow/core/framework/control_flow.h43
-rw-r--r--tensorflow/core/framework/device_attributes.proto35
-rw-r--r--tensorflow/core/framework/device_base.cc7
-rw-r--r--tensorflow/core/framework/device_base.h172
-rw-r--r--tensorflow/core/framework/fake_input.cc214
-rw-r--r--tensorflow/core/framework/fake_input.h25
-rw-r--r--tensorflow/core/framework/function.cc878
-rw-r--r--tensorflow/core/framework/function.h376
-rw-r--r--tensorflow/core/framework/function.proto68
-rw-r--r--tensorflow/core/framework/function_test.cc634
-rw-r--r--tensorflow/core/framework/function_testlib.cc146
-rw-r--r--tensorflow/core/framework/function_testlib.h53
-rw-r--r--tensorflow/core/framework/graph.proto103
-rw-r--r--tensorflow/core/framework/graph_def_util.cc25
-rw-r--r--tensorflow/core/framework/graph_def_util.h29
-rw-r--r--tensorflow/core/framework/kernel_def.proto33
-rw-r--r--tensorflow/core/framework/kernel_def_builder.cc47
-rw-r--r--tensorflow/core/framework/kernel_def_builder.h77
-rw-r--r--tensorflow/core/framework/kernel_def_builder_test.cc76
-rw-r--r--tensorflow/core/framework/lookup_interface.cc45
-rw-r--r--tensorflow/core/framework/lookup_interface.h65
-rw-r--r--tensorflow/core/framework/node_def_builder.cc194
-rw-r--r--tensorflow/core/framework/node_def_builder.h176
-rw-r--r--tensorflow/core/framework/node_def_builder_test.cc1036
-rw-r--r--tensorflow/core/framework/node_def_util.cc414
-rw-r--r--tensorflow/core/framework/node_def_util.h157
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc442
-rw-r--r--tensorflow/core/framework/numeric_op.h96
-rw-r--r--tensorflow/core/framework/numeric_types.h15
-rw-r--r--tensorflow/core/framework/op.cc135
-rw-r--r--tensorflow/core/framework/op.h122
-rw-r--r--tensorflow/core/framework/op_def.proto142
-rw-r--r--tensorflow/core/framework/op_def_builder.cc447
-rw-r--r--tensorflow/core/framework/op_def_builder.h109
-rw-r--r--tensorflow/core/framework/op_def_builder_test.cc519
-rw-r--r--tensorflow/core/framework/op_def_util.cc344
-rw-r--r--tensorflow/core/framework/op_def_util.h32
-rw-r--r--tensorflow/core/framework/op_def_util_test.cc330
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc55
-rw-r--r--tensorflow/core/framework/op_gen_lib.h24
-rw-r--r--tensorflow/core/framework/op_kernel.cc749
-rw-r--r--tensorflow/core/framework/op_kernel.h1250
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc803
-rw-r--r--tensorflow/core/framework/op_segment.cc86
-rw-r--r--tensorflow/core/framework/op_segment.h67
-rw-r--r--tensorflow/core/framework/op_segment_test.cc142
-rw-r--r--tensorflow/core/framework/queue_interface.h77
-rw-r--r--tensorflow/core/framework/reader_interface.h66
-rw-r--r--tensorflow/core/framework/reader_op_kernel.cc39
-rw-r--r--tensorflow/core/framework/reader_op_kernel.h42
-rw-r--r--tensorflow/core/framework/register_types.h90
-rw-r--r--tensorflow/core/framework/rendezvous.cc263
-rw-r--r--tensorflow/core/framework/rendezvous.h102
-rw-r--r--tensorflow/core/framework/rendezvous_test.cc314
-rw-r--r--tensorflow/core/framework/resource_mgr.cc146
-rw-r--r--tensorflow/core/framework/resource_mgr.h280
-rw-r--r--tensorflow/core/framework/resource_mgr_test.cc173
-rw-r--r--tensorflow/core/framework/step_stats.proto58
-rw-r--r--tensorflow/core/framework/summary.proto67
-rw-r--r--tensorflow/core/framework/tensor.cc570
-rw-r--r--tensorflow/core/framework/tensor.proto57
-rw-r--r--tensorflow/core/framework/tensor_description.proto19
-rw-r--r--tensorflow/core/framework/tensor_shape.cc138
-rw-r--r--tensorflow/core/framework/tensor_shape.proto29
-rw-r--r--tensorflow/core/framework/tensor_shape_test.cc75
-rw-r--r--tensorflow/core/framework/tensor_slice.cc226
-rw-r--r--tensorflow/core/framework/tensor_slice.h189
-rw-r--r--tensorflow/core/framework/tensor_slice.proto34
-rw-r--r--tensorflow/core/framework/tensor_slice_test.cc246
-rw-r--r--tensorflow/core/framework/tensor_test.cc551
-rw-r--r--tensorflow/core/framework/tensor_testutil.cc43
-rw-r--r--tensorflow/core/framework/tensor_testutil.h189
-rw-r--r--tensorflow/core/framework/tensor_types.h92
-rw-r--r--tensorflow/core/framework/tensor_util.cc28
-rw-r--r--tensorflow/core/framework/tensor_util.h21
-rw-r--r--tensorflow/core/framework/tensor_util_test.cc124
-rw-r--r--tensorflow/core/framework/tracking_allocator.cc100
-rw-r--r--tensorflow/core/framework/tracking_allocator.h80
-rw-r--r--tensorflow/core/framework/tracking_allocator_test.cc115
-rw-r--r--tensorflow/core/framework/type_traits.h69
-rw-r--r--tensorflow/core/framework/types.cc210
-rw-r--r--tensorflow/core/framework/types.h168
-rw-r--r--tensorflow/core/framework/types.proto48
-rw-r--r--tensorflow/core/framework/types_test.cc117
-rw-r--r--tensorflow/core/graph/algorithm.cc107
-rw-r--r--tensorflow/core/graph/algorithm.h40
-rw-r--r--tensorflow/core/graph/algorithm_test.cc103
-rw-r--r--tensorflow/core/graph/colors.cc25
-rw-r--r--tensorflow/core/graph/colors.h14
-rw-r--r--tensorflow/core/graph/costmodel.cc308
-rw-r--r--tensorflow/core/graph/costmodel.h123
-rw-r--r--tensorflow/core/graph/costutil.cc22
-rw-r--r--tensorflow/core/graph/costutil.h19
-rw-r--r--tensorflow/core/graph/default_device.h25
-rw-r--r--tensorflow/core/graph/dot.cc289
-rw-r--r--tensorflow/core/graph/dot.h43
-rw-r--r--tensorflow/core/graph/edgeset.cc56
-rw-r--r--tensorflow/core/graph/edgeset.h216
-rw-r--r--tensorflow/core/graph/edgeset_test.cc95
-rw-r--r--tensorflow/core/graph/equal_graph_def.cc176
-rw-r--r--tensorflow/core/graph/equal_graph_def.h32
-rw-r--r--tensorflow/core/graph/equal_graph_def_test.cc279
-rw-r--r--tensorflow/core/graph/graph.cc319
-rw-r--r--tensorflow/core/graph/graph.h440
-rw-r--r--tensorflow/core/graph/graph_constructor.cc385
-rw-r--r--tensorflow/core/graph/graph_constructor.h43
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc190
-rw-r--r--tensorflow/core/graph/graph_def_builder.cc121
-rw-r--r--tensorflow/core/graph/graph_def_builder.h181
-rw-r--r--tensorflow/core/graph/graph_partition.cc1050
-rw-r--r--tensorflow/core/graph/graph_partition.h77
-rw-r--r--tensorflow/core/graph/graph_partition_test.cc316
-rw-r--r--tensorflow/core/graph/graph_test.cc252
-rw-r--r--tensorflow/core/graph/node_builder.cc115
-rw-r--r--tensorflow/core/graph/node_builder.h146
-rw-r--r--tensorflow/core/graph/node_builder_test.cc59
-rw-r--r--tensorflow/core/graph/optimizer_cse.cc220
-rw-r--r--tensorflow/core/graph/optimizer_cse.h19
-rw-r--r--tensorflow/core/graph/optimizer_cse_test.cc365
-rw-r--r--tensorflow/core/graph/subgraph.cc258
-rw-r--r--tensorflow/core/graph/subgraph.h49
-rw-r--r--tensorflow/core/graph/subgraph_test.cc305
-rw-r--r--tensorflow/core/graph/tensor_id.cc41
-rw-r--r--tensorflow/core/graph/tensor_id.h28
-rw-r--r--tensorflow/core/graph/tensor_id_test.cc77
-rw-r--r--tensorflow/core/graph/testlib.cc299
-rw-r--r--tensorflow/core/graph/testlib.h141
-rw-r--r--tensorflow/core/graph/types.h17
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op.cc121
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op.h64
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc43
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc22
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op_test.cc88
-rw-r--r--tensorflow/core/kernels/aggregate_ops.cc238
-rw-r--r--tensorflow/core/kernels/aggregate_ops.h211
-rw-r--r--tensorflow/core/kernels/aggregate_ops_gpu.cu.cc141
-rw-r--r--tensorflow/core/kernels/argmax_op.cc163
-rw-r--r--tensorflow/core/kernels/argmax_op.h55
-rw-r--r--tensorflow/core/kernels/argmax_op_gpu.cu.cc20
-rw-r--r--tensorflow/core/kernels/assign_op.h92
-rw-r--r--tensorflow/core/kernels/attention_ops.cc92
-rw-r--r--tensorflow/core/kernels/avgpooling_op.cc418
-rw-r--r--tensorflow/core/kernels/avgpooling_op.h58
-rw-r--r--tensorflow/core/kernels/avgpooling_op_gpu.cu.cc101
-rw-r--r--tensorflow/core/kernels/batch_matmul_op.cc260
-rw-r--r--tensorflow/core/kernels/batch_norm_op.cc223
-rw-r--r--tensorflow/core/kernels/batch_norm_op.h133
-rw-r--r--tensorflow/core/kernels/batch_norm_op_gpu.cu.cc17
-rw-r--r--tensorflow/core/kernels/bcast_ops.cc71
-rw-r--r--tensorflow/core/kernels/bias_op.cc112
-rw-r--r--tensorflow/core/kernels/bias_op.h41
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.cu.cc23
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc243
-rw-r--r--tensorflow/core/kernels/cast_op.cc233
-rw-r--r--tensorflow/core/kernels/cast_op.h71
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc45
-rw-r--r--tensorflow/core/kernels/cast_op_test.cc100
-rw-r--r--tensorflow/core/kernels/check_numerics_op.cc190
-rw-r--r--tensorflow/core/kernels/check_numerics_op_gpu.cu.cc62
-rw-r--r--tensorflow/core/kernels/cholesky_op.cc71
-rw-r--r--tensorflow/core/kernels/concat_op.cc153
-rw-r--r--tensorflow/core/kernels/concat_op.h27
-rw-r--r--tensorflow/core/kernels/concat_op_cpu.cc122
-rw-r--r--tensorflow/core/kernels/concat_op_gpu.cu.cc41
-rw-r--r--tensorflow/core/kernels/concat_op_test.cc240
-rw-r--r--tensorflow/core/kernels/constant_op.cc249
-rw-r--r--tensorflow/core/kernels/constant_op.h25
-rw-r--r--tensorflow/core/kernels/constant_op_gpu.cu.cc89
-rw-r--r--tensorflow/core/kernels/constant_op_test.cc43
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc359
-rw-r--r--tensorflow/core/kernels/control_flow_ops.h22
-rw-r--r--tensorflow/core/kernels/control_flow_ops_test.cc71
-rw-r--r--tensorflow/core/kernels/conv_2d.h127
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc1190
-rw-r--r--tensorflow/core/kernels/conv_ops.cc373
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.cu.cc35
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_2.cu.cc16
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc22
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc16
-rw-r--r--tensorflow/core/kernels/core_ops_test.cc990
-rw-r--r--tensorflow/core/kernels/count_up_to_op.cc51
-rw-r--r--tensorflow/core/kernels/cwise_op_abs.cc23
-rw-r--r--tensorflow/core/kernels/cwise_op_add.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_ceil.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_complex.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_conj.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_cos.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_equal_to.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_exp.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_floor.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_add.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_div.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_less.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_log.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc13
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc13
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_real.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_select.cu.cc15
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_square.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc11
-rw-r--r--tensorflow/core/kernels/cwise_op_greater.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_greater_equal.cc22
-rw-r--r--tensorflow/core/kernels/cwise_op_imag.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_inverse.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_isfinite.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_isinf.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_isnan.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_less.cc20
-rw-r--r--tensorflow/core/kernels/cwise_op_less_equal.cc22
-rw-r--r--tensorflow/core/kernels/cwise_op_log.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_logical_and.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_logical_not.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_logical_or.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_maximum.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_minimum.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_mod.cc6
-rw-r--r--tensorflow/core/kernels/cwise_op_mul.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_neg.cc9
-rw-r--r--tensorflow/core/kernels/cwise_op_not_equal_to.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_pow.cc9
-rw-r--r--tensorflow/core/kernels/cwise_op_real.cc10
-rw-r--r--tensorflow/core/kernels/cwise_op_rsqrt.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc17
-rw-r--r--tensorflow/core/kernels/cwise_op_sigmoid.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_sign.cc19
-rw-r--r--tensorflow/core/kernels/cwise_op_sin.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_sqrt.cc8
-rw-r--r--tensorflow/core/kernels/cwise_op_square.cc9
-rw-r--r--tensorflow/core/kernels/cwise_op_sub.cc21
-rw-r--r--tensorflow/core/kernels/cwise_op_tanh.cc8
-rw-r--r--tensorflow/core/kernels/cwise_ops.h607
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc42
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.h390
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_common.cu.h135
-rw-r--r--tensorflow/core/kernels/cwise_ops_test.cc167
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc222
-rw-r--r--tensorflow/core/kernels/decode_jpeg_op.cc72
-rw-r--r--tensorflow/core/kernels/decode_png_op.cc69
-rw-r--r--tensorflow/core/kernels/decode_raw_op.cc90
-rw-r--r--tensorflow/core/kernels/dense_update_ops.cc136
-rw-r--r--tensorflow/core/kernels/dense_update_ops.h43
-rw-r--r--tensorflow/core/kernels/dense_update_ops_gpu.cu.cc22
-rw-r--r--tensorflow/core/kernels/determinant_op.cc66
-rw-r--r--tensorflow/core/kernels/diag_op.cc93
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op.cc154
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op_test.cc145
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op.cc158
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op_test.cc133
-rw-r--r--tensorflow/core/kernels/edit_distance_op.cc217
-rw-r--r--tensorflow/core/kernels/encode_jpeg_op.cc114
-rw-r--r--tensorflow/core/kernels/encode_png_op.cc52
-rw-r--r--tensorflow/core/kernels/example_parsing_ops.cc444
-rw-r--r--tensorflow/core/kernels/fact_op.cc96
-rw-r--r--tensorflow/core/kernels/fifo_queue.cc518
-rw-r--r--tensorflow/core/kernels/fifo_queue.h127
-rw-r--r--tensorflow/core/kernels/fifo_queue_op.cc93
-rw-r--r--tensorflow/core/kernels/fill_functor.h26
-rw-r--r--tensorflow/core/kernels/fixed_length_record_reader_op.cc109
-rw-r--r--tensorflow/core/kernels/gather_op.cc136
-rw-r--r--tensorflow/core/kernels/gather_op_test.cc213
-rw-r--r--tensorflow/core/kernels/identity_op.cc45
-rw-r--r--tensorflow/core/kernels/identity_op.h25
-rw-r--r--tensorflow/core/kernels/identity_op_test.cc56
-rw-r--r--tensorflow/core/kernels/identity_reader_op.cc57
-rw-r--r--tensorflow/core/kernels/in_topk_op.cc58
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.cc41
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h103
-rw-r--r--tensorflow/core/kernels/io.cc270
-rw-r--r--tensorflow/core/kernels/io.h38
-rw-r--r--tensorflow/core/kernels/l2loss_op.cc69
-rw-r--r--tensorflow/core/kernels/l2loss_op.h24
-rw-r--r--tensorflow/core/kernels/l2loss_op_gpu.cu.cc16
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.cc99
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.h123
-rw-r--r--tensorflow/core/kernels/listdiff_op.cc75
-rw-r--r--tensorflow/core/kernels/logging_ops.cc77
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc87
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.cc116
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc166
-rw-r--r--tensorflow/core/kernels/lookup_table_op.h80
-rw-r--r--tensorflow/core/kernels/lookup_util.cc72
-rw-r--r--tensorflow/core/kernels/lookup_util.h31
-rw-r--r--tensorflow/core/kernels/lrn_op.cc228
-rw-r--r--tensorflow/core/kernels/lrn_op_test.cc185
-rw-r--r--tensorflow/core/kernels/matching_files_op.cc42
-rw-r--r--tensorflow/core/kernels/matmul_op.cc214
-rw-r--r--tensorflow/core/kernels/matmul_op.h40
-rw-r--r--tensorflow/core/kernels/matmul_op_gpu.cu.cc32
-rw-r--r--tensorflow/core/kernels/matmul_op_test.cc56
-rw-r--r--tensorflow/core/kernels/matrix_inverse_op.cc64
-rw-r--r--tensorflow/core/kernels/maxpooling_op.cc554
-rw-r--r--tensorflow/core/kernels/maxpooling_op.h29
-rw-r--r--tensorflow/core/kernels/maxpooling_op_gpu.cu.cc261
-rw-r--r--tensorflow/core/kernels/maxpooling_op_gpu.h42
-rw-r--r--tensorflow/core/kernels/no_op.cc8
-rw-r--r--tensorflow/core/kernels/no_op.h17
-rw-r--r--tensorflow/core/kernels/ops_testutil.cc18
-rw-r--r--tensorflow/core/kernels/ops_testutil.h191
-rw-r--r--tensorflow/core/kernels/ops_util.cc113
-rw-r--r--tensorflow/core/kernels/ops_util.h180
-rw-r--r--tensorflow/core/kernels/ops_util_test.cc265
-rw-r--r--tensorflow/core/kernels/pack_op.cc114
-rw-r--r--tensorflow/core/kernels/pad_op.cc159
-rw-r--r--tensorflow/core/kernels/pad_op.h27
-rw-r--r--tensorflow/core/kernels/pad_op_gpu.cu.cc26
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.cc252
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.h264
-rw-r--r--tensorflow/core/kernels/pooling_ops_common_gpu.h39
-rw-r--r--tensorflow/core/kernels/queue_base.cc153
-rw-r--r--tensorflow/core/kernels/queue_base.h77
-rw-r--r--tensorflow/core/kernels/queue_ops.cc288
-rw-r--r--tensorflow/core/kernels/random_crop_op.cc103
-rw-r--r--tensorflow/core/kernels/random_crop_op_test.cc60
-rw-r--r--tensorflow/core/kernels/random_op.cc276
-rw-r--r--tensorflow/core/kernels/random_op.h16
-rw-r--r--tensorflow/core/kernels/random_op_gpu.cu.cc152
-rw-r--r--tensorflow/core/kernels/random_op_test.cc99
-rw-r--r--tensorflow/core/kernels/random_shuffle_op.cc89
-rw-r--r--tensorflow/core/kernels/random_shuffle_queue_op.cc740
-rw-r--r--tensorflow/core/kernels/range_sampler.cc305
-rw-r--r--tensorflow/core/kernels/range_sampler.h237
-rw-r--r--tensorflow/core/kernels/range_sampler_test.cc320
-rw-r--r--tensorflow/core/kernels/reader_base.cc156
-rw-r--r--tensorflow/core/kernels/reader_base.h107
-rw-r--r--tensorflow/core/kernels/reader_base.proto13
-rw-r--r--tensorflow/core/kernels/reader_ops.cc132
-rw-r--r--tensorflow/core/kernels/reduction_ops.h66
-rw-r--r--tensorflow/core/kernels/reduction_ops_all.cc17
-rw-r--r--tensorflow/core/kernels/reduction_ops_any.cc17
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h302
-rw-r--r--tensorflow/core/kernels/reduction_ops_gpu.cu.cc65
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc26
-rw-r--r--tensorflow/core/kernels/reduction_ops_mean.cc12
-rw-r--r--tensorflow/core/kernels/reduction_ops_min.cc26
-rw-r--r--tensorflow/core/kernels/reduction_ops_prod.cc26
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc37
-rw-r--r--tensorflow/core/kernels/reduction_ops_test.cc73
-rw-r--r--tensorflow/core/kernels/reference_gemm.h75
-rw-r--r--tensorflow/core/kernels/relu_op.cc154
-rw-r--r--tensorflow/core/kernels/relu_op.h79
-rw-r--r--tensorflow/core/kernels/relu_op_gpu.cu.cc27
-rw-r--r--tensorflow/core/kernels/reshape_op.cc29
-rw-r--r--tensorflow/core/kernels/reshape_op.h83
-rw-r--r--tensorflow/core/kernels/resize_area_op.cc139
-rw-r--r--tensorflow/core/kernels/resize_bicubic_op.cc121
-rw-r--r--tensorflow/core/kernels/resize_bilinear_op.cc109
-rw-r--r--tensorflow/core/kernels/resize_bilinear_op_test.cc171
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc89
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc163
-rw-r--r--tensorflow/core/kernels/restore_op.cc65
-rw-r--r--tensorflow/core/kernels/restore_op_test.cc305
-rw-r--r--tensorflow/core/kernels/reverse_op.cc139
-rw-r--r--tensorflow/core/kernels/reverse_op.h28
-rw-r--r--tensorflow/core/kernels/reverse_op_gpu.cu.cc33
-rw-r--r--tensorflow/core/kernels/reverse_op_test.cc101
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.cc170
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.h56
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc26
-rw-r--r--tensorflow/core/kernels/save_op.cc81
-rw-r--r--tensorflow/core/kernels/save_op_test.cc443
-rw-r--r--tensorflow/core/kernels/scatter_op.cc167
-rw-r--r--tensorflow/core/kernels/scatter_op_test.cc255
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc466
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops_test.cc157
-rw-r--r--tensorflow/core/kernels/sendrecv_ops.cc116
-rw-r--r--tensorflow/core/kernels/sendrecv_ops.h32
-rw-r--r--tensorflow/core/kernels/sequence_ops.cc123
-rw-r--r--tensorflow/core/kernels/shape_ops.cc261
-rw-r--r--tensorflow/core/kernels/slice_op.cc242
-rw-r--r--tensorflow/core/kernels/slice_op.h25
-rw-r--r--tensorflow/core/kernels/slice_op_gpu.cu.cc31
-rw-r--r--tensorflow/core/kernels/slice_op_test.cc73
-rw-r--r--tensorflow/core/kernels/softmax_op.cc62
-rw-r--r--tensorflow/core/kernels/softmax_op.h70
-rw-r--r--tensorflow/core/kernels/softmax_op_gpu.cu.cc31
-rw-r--r--tensorflow/core/kernels/softplus_op.cc97
-rw-r--r--tensorflow/core/kernels/softplus_op.h46
-rw-r--r--tensorflow/core/kernels/softplus_op_gpu.cu.cc25
-rw-r--r--tensorflow/core/kernels/sparse_concat_op.cc139
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.cc192
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op_test.cc139
-rw-r--r--tensorflow/core/kernels/sparse_reorder_op.cc71
-rw-r--r--tensorflow/core/kernels/sparse_to_dense_op.cc129
-rw-r--r--tensorflow/core/kernels/sparse_to_dense_op_test.cc283
-rw-r--r--tensorflow/core/kernels/split_op.cc146
-rw-r--r--tensorflow/core/kernels/split_op.h31
-rw-r--r--tensorflow/core/kernels/split_op_cpu.cc30
-rw-r--r--tensorflow/core/kernels/split_op_gpu.cu.cc31
-rw-r--r--tensorflow/core/kernels/string_to_hash_bucket_op.cc47
-rw-r--r--tensorflow/core/kernels/string_to_number_op.cc71
-rw-r--r--tensorflow/core/kernels/summary_image_op.cc169
-rw-r--r--tensorflow/core/kernels/summary_image_op_test.cc141
-rw-r--r--tensorflow/core/kernels/summary_op.cc141
-rw-r--r--tensorflow/core/kernels/summary_op_test.cc282
-rw-r--r--tensorflow/core/kernels/text_line_reader_op.cc99
-rw-r--r--tensorflow/core/kernels/tf_record_reader_op.cc76
-rw-r--r--tensorflow/core/kernels/tile_ops.cc460
-rw-r--r--tensorflow/core/kernels/tile_ops.h48
-rw-r--r--tensorflow/core/kernels/tile_ops_gpu.cu.cc38
-rw-r--r--tensorflow/core/kernels/topk_op.cc71
-rw-r--r--tensorflow/core/kernels/training_ops.cc884
-rw-r--r--tensorflow/core/kernels/training_ops.h65
-rw-r--r--tensorflow/core/kernels/training_ops_gpu.cu.cc127
-rw-r--r--tensorflow/core/kernels/training_ops_test.cc226
-rw-r--r--tensorflow/core/kernels/transpose_op.cc190
-rw-r--r--tensorflow/core/kernels/transpose_op.h19
-rw-r--r--tensorflow/core/kernels/transpose_op_functor.h28
-rw-r--r--tensorflow/core/kernels/transpose_op_gpu.cu.cc43
-rw-r--r--tensorflow/core/kernels/unique_op.cc61
-rw-r--r--tensorflow/core/kernels/unique_op_test.cc51
-rw-r--r--tensorflow/core/kernels/unpack_op.cc96
-rw-r--r--tensorflow/core/kernels/variable_ops.cc37
-rw-r--r--tensorflow/core/kernels/variable_ops.h146
-rw-r--r--tensorflow/core/kernels/where_op.cc74
-rw-r--r--tensorflow/core/kernels/where_op.h65
-rw-r--r--tensorflow/core/kernels/whole_file_read_ops.cc108
-rw-r--r--tensorflow/core/kernels/xent_op.cc90
-rw-r--r--tensorflow/core/kernels/xent_op.h102
-rw-r--r--tensorflow/core/kernels/xent_op_gpu.cu.cc35
-rw-r--r--tensorflow/core/kernels/xent_op_test.cc46
-rw-r--r--tensorflow/core/lib/core/arena.cc246
-rw-r--r--tensorflow/core/lib/core/arena.h90
-rw-r--r--tensorflow/core/lib/core/arena_test.cc92
-rw-r--r--tensorflow/core/lib/core/bit_cast_test.cc95
-rw-r--r--tensorflow/core/lib/core/bits.h84
-rw-r--r--tensorflow/core/lib/core/blocking_counter.h41
-rw-r--r--tensorflow/core/lib/core/blocking_counter_test.cc36
-rw-r--r--tensorflow/core/lib/core/casts.h85
-rw-r--r--tensorflow/core/lib/core/coding.cc164
-rw-r--r--tensorflow/core/lib/core/coding.h55
-rw-r--r--tensorflow/core/lib/core/coding_test.cc168
-rw-r--r--tensorflow/core/lib/core/command_line_flags.cc94
-rw-r--r--tensorflow/core/lib/core/command_line_flags.h60
-rw-r--r--tensorflow/core/lib/core/error_codes.proto145
-rw-r--r--tensorflow/core/lib/core/errors.h131
-rw-r--r--tensorflow/core/lib/core/notification.h42
-rw-r--r--tensorflow/core/lib/core/notification_test.cc64
-rw-r--r--tensorflow/core/lib/core/raw_coding.h43
-rw-r--r--tensorflow/core/lib/core/refcount.cc35
-rw-r--r--tensorflow/core/lib/core/refcount.h63
-rw-r--r--tensorflow/core/lib/core/refcount_test.cc92
-rw-r--r--tensorflow/core/lib/core/status.cc107
-rw-r--r--tensorflow/core/lib/core/status_test.cc84
-rw-r--r--tensorflow/core/lib/core/status_test_util.h20
-rw-r--r--tensorflow/core/lib/core/stringpiece.cc57
-rw-r--r--tensorflow/core/lib/core/stringpiece.h159
-rw-r--r--tensorflow/core/lib/core/threadpool.cc108
-rw-r--r--tensorflow/core/lib/core/threadpool.h59
-rw-r--r--tensorflow/core/lib/core/threadpool_test.cc93
-rw-r--r--tensorflow/core/lib/gtl/array_slice.h299
-rw-r--r--tensorflow/core/lib/gtl/array_slice_internal.h253
-rw-r--r--tensorflow/core/lib/gtl/array_slice_test.cc646
-rw-r--r--tensorflow/core/lib/gtl/edit_distance.h82
-rw-r--r--tensorflow/core/lib/gtl/edit_distance_test.cc125
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector.h839
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector_test.cc905
-rw-r--r--tensorflow/core/lib/gtl/int_type.h343
-rw-r--r--tensorflow/core/lib/gtl/int_type_test.cc282
-rw-r--r--tensorflow/core/lib/gtl/iterator_range.h49
-rw-r--r--tensorflow/core/lib/gtl/iterator_range_test.cc60
-rw-r--r--tensorflow/core/lib/gtl/manual_constructor.h230
-rw-r--r--tensorflow/core/lib/gtl/manual_constructor_test.cc113
-rw-r--r--tensorflow/core/lib/gtl/map_util.h123
-rw-r--r--tensorflow/core/lib/gtl/map_util_test.cc47
-rw-r--r--tensorflow/core/lib/gtl/stl_util.h130
-rw-r--r--tensorflow/core/lib/gtl/top_n.h324
-rw-r--r--tensorflow/core/lib/gtl/top_n_test.cc249
-rw-r--r--tensorflow/core/lib/hash/crc32c.cc244
-rw-r--r--tensorflow/core/lib/hash/crc32c.h39
-rw-r--r--tensorflow/core/lib/hash/crc32c_test.cc51
-rw-r--r--tensorflow/core/lib/hash/hash.cc113
-rw-r--r--tensorflow/core/lib/hash/hash.h28
-rw-r--r--tensorflow/core/lib/hash/hash_test.cc64
-rw-r--r--tensorflow/core/lib/histogram/histogram.cc247
-rw-r--r--tensorflow/core/lib/histogram/histogram.h119
-rw-r--r--tensorflow/core/lib/histogram/histogram_test.cc112
-rw-r--r--tensorflow/core/lib/io/block.cc236
-rw-r--r--tensorflow/core/lib/io/block.h45
-rw-r--r--tensorflow/core/lib/io/block_builder.cc107
-rw-r--r--tensorflow/core/lib/io/block_builder.h57
-rw-r--r--tensorflow/core/lib/io/format.cc148
-rw-r--r--tensorflow/core/lib/io/format.h99
-rw-r--r--tensorflow/core/lib/io/inputbuffer.cc112
-rw-r--r--tensorflow/core/lib/io/inputbuffer.h62
-rw-r--r--tensorflow/core/lib/io/inputbuffer_test.cc174
-rw-r--r--tensorflow/core/lib/io/iterator.cc72
-rw-r--r--tensorflow/core/lib/io/iterator.h93
-rw-r--r--tensorflow/core/lib/io/match.cc31
-rw-r--r--tensorflow/core/lib/io/match.h24
-rw-r--r--tensorflow/core/lib/io/match_test.cc51
-rw-r--r--tensorflow/core/lib/io/path.cc92
-rw-r--r--tensorflow/core/lib/io/path.h47
-rw-r--r--tensorflow/core/lib/io/path_test.cc65
-rw-r--r--tensorflow/core/lib/io/record_reader.cc80
-rw-r--r--tensorflow/core/lib/io/record_reader.h36
-rw-r--r--tensorflow/core/lib/io/record_writer.cc42
-rw-r--r--tensorflow/core/lib/io/record_writer.h34
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc245
-rw-r--r--tensorflow/core/lib/io/table.cc169
-rw-r--r--tensorflow/core/lib/io/table.h76
-rw-r--r--tensorflow/core/lib/io/table_builder.cc263
-rw-r--r--tensorflow/core/lib/io/table_builder.h87
-rw-r--r--tensorflow/core/lib/io/table_format.txt8
-rw-r--r--tensorflow/core/lib/io/table_options.h53
-rw-r--r--tensorflow/core/lib/io/table_test.cc601
-rw-r--r--tensorflow/core/lib/io/two_level_iterator.cc148
-rw-r--r--tensorflow/core/lib/io/two_level_iterator.h30
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_handle.cc162
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_handle.h51
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.cc557
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.h130
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc304
-rw-r--r--tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg  bin 0 -> 15416 bytes
-rw-r--r--tensorflow/core/lib/jpeg/testdata/corrupt.jpg  bin 0 -> 1552 bytes
-rw-r--r--tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg  bin 0 -> 755 bytes
-rw-r--r--tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg  bin 0 -> 5505 bytes
-rw-r--r--tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg  bin 0 -> 5092 bytes
-rw-r--r--tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg  bin 0 -> 3771 bytes
-rw-r--r--tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg  bin 0 -> 5324 bytes
-rw-r--r--tensorflow/core/lib/png/png_io.cc385
-rw-r--r--tensorflow/core/lib/png/png_io.h88
-rw-r--r--tensorflow/core/lib/png/testdata/lena_gray.png  bin 0 -> 1491 bytes
-rw-r--r--tensorflow/core/lib/png/testdata/lena_rgba.png  bin 0 -> 4032 bytes
-rw-r--r--tensorflow/core/lib/random/distribution_sampler.cc80
-rw-r--r--tensorflow/core/lib/random/distribution_sampler.h79
-rw-r--r--tensorflow/core/lib/random/distribution_sampler_test.cc90
-rw-r--r--tensorflow/core/lib/random/exact_uniform_int.h68
-rw-r--r--tensorflow/core/lib/random/philox_random.h232
-rw-r--r--tensorflow/core/lib/random/philox_random_test.cc58
-rw-r--r--tensorflow/core/lib/random/philox_random_test_utils.h36
-rw-r--r--tensorflow/core/lib/random/random.cc22
-rw-r--r--tensorflow/core/lib/random/random.h16
-rw-r--r--tensorflow/core/lib/random/random_distributions.h361
-rw-r--r--tensorflow/core/lib/random/random_distributions_test.cc270
-rw-r--r--tensorflow/core/lib/random/random_test.cc21
-rw-r--r--tensorflow/core/lib/random/simple_philox.cc24
-rw-r--r--tensorflow/core/lib/random/simple_philox.h61
-rw-r--r--tensorflow/core/lib/random/simple_philox_test.cc120
-rw-r--r--tensorflow/core/lib/random/weighted_picker.cc203
-rw-r--r--tensorflow/core/lib/random/weighted_picker.h118
-rw-r--r--tensorflow/core/lib/random/weighted_picker_test.cc254
-rw-r--r--tensorflow/core/lib/strings/numbers.cc260
-rw-r--r--tensorflow/core/lib/strings/numbers.h92
-rw-r--r--tensorflow/core/lib/strings/numbers_test.cc113
-rw-r--r--tensorflow/core/lib/strings/ordered_code.cc515
-rw-r--r--tensorflow/core/lib/strings/ordered_code.h77
-rw-r--r--tensorflow/core/lib/strings/ordered_code_test.cc1183
-rw-r--r--tensorflow/core/lib/strings/str_util.cc312
-rw-r--r--tensorflow/core/lib/strings/str_util.h149
-rw-r--r--tensorflow/core/lib/strings/str_util_test.cc258
-rw-r--r--tensorflow/core/lib/strings/strcat.cc194
-rw-r--r--tensorflow/core/lib/strings/strcat.h229
-rw-r--r--tensorflow/core/lib/strings/strcat_test.cc324
-rw-r--r--tensorflow/core/lib/strings/stringprintf.cc85
-rw-r--r--tensorflow/core/lib/strings/stringprintf.h37
-rw-r--r--tensorflow/core/lib/strings/stringprintf_test.cc113
-rw-r--r--tensorflow/core/ops/array_ops.cc892
-rw-r--r--tensorflow/core/ops/attention_ops.cc54
-rw-r--r--tensorflow/core/ops/candidate_sampling_ops.cc351
-rw-r--r--tensorflow/core/ops/control_flow_ops.cc179
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc357
-rw-r--r--tensorflow/core/ops/image_ops.cc273
-rw-r--r--tensorflow/core/ops/io_ops.cc332
-rw-r--r--tensorflow/core/ops/linalg_ops.cc97
-rw-r--r--tensorflow/core/ops/logging_ops.cc43
-rw-r--r--tensorflow/core/ops/math_ops.cc1053
-rw-r--r--tensorflow/core/ops/nn_ops.cc543
-rw-r--r--tensorflow/core/ops/no_op.cc10
-rw-r--r--tensorflow/core/ops/parsing_ops.cc104
-rw-r--r--tensorflow/core/ops/random_ops.cc108
-rw-r--r--tensorflow/core/ops/sendrecv_ops.cc99
-rw-r--r--tensorflow/core/ops/sparse_ops.cc134
-rw-r--r--tensorflow/core/ops/state_ops.cc290
-rw-r--r--tensorflow/core/ops/string_ops.cc21
-rw-r--r--tensorflow/core/ops/summary_ops.cc115
-rw-r--r--tensorflow/core/ops/training_ops.cc199
-rw-r--r--tensorflow/core/platform/default/build_config.bzl65
-rw-r--r--tensorflow/core/platform/default/build_config/BUILD85
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl6
-rw-r--r--tensorflow/core/platform/default/dynamic_annotations.h9
-rw-r--r--tensorflow/core/platform/default/integral_types.h18
-rw-r--r--tensorflow/core/platform/default/logging.cc125
-rw-r--r--tensorflow/core/platform/default/logging.h258
-rw-r--r--tensorflow/core/platform/default/mutex.h33
-rw-r--r--tensorflow/core/platform/default/protobuf.h13
-rw-r--r--tensorflow/core/platform/default/stream_executor_util.h19
-rw-r--r--tensorflow/core/platform/default/test_benchmark.cc162
-rw-r--r--tensorflow/core/platform/default/thread_annotations.h185
-rw-r--r--tensorflow/core/platform/default/tracing.cc37
-rw-r--r--tensorflow/core/platform/default/tracing_impl.h44
-rw-r--r--tensorflow/core/platform/env.cc129
-rw-r--r--tensorflow/core/platform/env_test.cc31
-rw-r--r--tensorflow/core/platform/init_main.h16
-rw-r--r--tensorflow/core/platform/integral_types_test.cc33
-rw-r--r--tensorflow/core/platform/logging.h12
-rw-r--r--tensorflow/core/platform/logging_test.cc76
-rw-r--r--tensorflow/core/platform/port.h228
-rw-r--r--tensorflow/core/platform/port_test.cc48
-rw-r--r--tensorflow/core/platform/posix/env.cc385
-rw-r--r--tensorflow/core/platform/posix/port.cc92
-rw-r--r--tensorflow/core/platform/protobuf.h29
-rw-r--r--tensorflow/core/platform/protobuf_util.cc17
-rw-r--r--tensorflow/core/platform/regexp.h33
-rw-r--r--tensorflow/core/platform/stream_executor_util.h12
-rw-r--r--tensorflow/core/platform/tensor_coding.cc53
-rw-r--r--tensorflow/core/platform/tensor_coding.h40
-rw-r--r--tensorflow/core/platform/test.cc39
-rw-r--r--tensorflow/core/platform/test.h17
-rw-r--r--tensorflow/core/platform/test_benchmark.h58
-rw-r--r--tensorflow/core/platform/test_main.cc31
-rw-r--r--tensorflow/core/platform/thread_annotations.h14
-rw-r--r--tensorflow/core/platform/tracing.cc135
-rw-r--r--tensorflow/core/platform/tracing.h205
-rw-r--r--tensorflow/core/public/README.md90
-rw-r--r--tensorflow/core/public/env.h273
-rw-r--r--tensorflow/core/public/session.h125
-rw-r--r--tensorflow/core/public/session_options.h50
-rw-r--r--tensorflow/core/public/status.h96
-rw-r--r--tensorflow/core/public/tensor.h472
-rw-r--r--tensorflow/core/public/tensor_c_api.h243
-rw-r--r--tensorflow/core/public/tensor_shape.h239
-rw-r--r--tensorflow/core/public/tensorflow_server.h19
-rw-r--r--tensorflow/core/user_ops/fact.cc29
-rw-r--r--tensorflow/core/util/bcast.cc120
-rw-r--r--tensorflow/core/util/bcast.h99
-rw-r--r--tensorflow/core/util/bcast_test.cc226
-rw-r--r--tensorflow/core/util/device_name_utils.cc338
-rw-r--r--tensorflow/core/util/device_name_utils.h141
-rw-r--r--tensorflow/core/util/device_name_utils_test.cc369
-rw-r--r--tensorflow/core/util/event.proto29
-rw-r--r--tensorflow/core/util/events_writer.cc144
-rw-r--r--tensorflow/core/util/events_writer.h77
-rw-r--r--tensorflow/core/util/events_writer_test.cc198
-rw-r--r--tensorflow/core/util/guarded_philox_random.cc39
-rw-r--r--tensorflow/core/util/guarded_philox_random.h56
-rw-r--r--tensorflow/core/util/padding.cc24
-rw-r--r--tensorflow/core/util/padding.h37
-rw-r--r--tensorflow/core/util/port.cc13
-rw-r--r--tensorflow/core/util/port.h11
-rw-r--r--tensorflow/core/util/saved_tensor_slice.proto76
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.cc76
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h110
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util_test.cc32
-rw-r--r--tensorflow/core/util/sparse/README.md222
-rw-r--r--tensorflow/core/util/sparse/dim_comparator.h60
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc49
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h120
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h353
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor_test.cc467
-rw-r--r--tensorflow/core/util/tensor_slice_reader.cc230
-rw-r--r--tensorflow/core/util/tensor_slice_reader.h157
-rw-r--r--tensorflow/core/util/tensor_slice_reader_cache.cc94
-rw-r--r--tensorflow/core/util/tensor_slice_reader_cache.h73
-rw-r--r--tensorflow/core/util/tensor_slice_reader_test.cc395
-rw-r--r--tensorflow/core/util/tensor_slice_set.cc148
-rw-r--r--tensorflow/core/util/tensor_slice_set.h73
-rw-r--r--tensorflow/core/util/tensor_slice_set_test.cc227
-rw-r--r--tensorflow/core/util/tensor_slice_util.h88
-rw-r--r--tensorflow/core/util/tensor_slice_util_test.cc91
-rw-r--r--tensorflow/core/util/tensor_slice_writer.cc110
-rw-r--r--tensorflow/core/util/tensor_slice_writer.h149
-rw-r--r--tensorflow/core/util/tensor_slice_writer_test.cc248
-rw-r--r--tensorflow/core/util/use_cudnn.cc20
-rw-r--r--tensorflow/core/util/use_cudnn.h12
-rw-r--r--tensorflow/core/util/util.cc81
-rw-r--r--tensorflow/core/util/util.h40
-rw-r--r--tensorflow/core/util/work_sharder.cc57
-rw-r--r--tensorflow/core/util/work_sharder.h33
-rw-r--r--tensorflow/core/util/work_sharder_test.cc57
788 files changed, 108161 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
new file mode 100644
index 0000000000..c2fcfeed8c
--- /dev/null
+++ b/tensorflow/core/BUILD
@@ -0,0 +1,695 @@
+# Description:
+# TensorFlow is a computational framework, primarily for use in machine
+# learning applications.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+package_group(name = "friends")
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("/tensorflow/tensorflow", "tf_copts")
+load("/tensorflow/tensorflow", "tf_cc_tests")
+load("/tensorflow/tensorflow", "tf_cuda_library")
+load("/tensorflow/tensorflow", "tf_gen_op_libs")
+load("/tensorflow/tensorflow", "tf_gpu_kernel_library")
+
+# For platform specific build config
+load(
+ "/tensorflow/core/platform/default/build_config",
+ "tf_proto_library",
+ "tf_additional_lib_srcs",
+ "tf_additional_test_srcs",
+ "tf_kernel_tests_linkstatic",
+)
+load(
+ "/tensorflow/core/platform/default/build_config_root",
+ "tf_cuda_tests_tags",
+)
+
+cc_library(
+ name = "lib",
+ srcs = glob(
+ [
+ "lib/**/*.h",
+ "lib/**/*.cc",
+ "platform/*.h",
+ "platform/*.cc",
+ "public/*.h",
+ ] + tf_additional_lib_srcs(),
+ exclude = [
+ "**/*test*",
+ ],
+ ),
+ copts = tf_copts(),
+ visibility = [
+ ":friends",
+ "//tensorflow:internal",
+ ],
+ deps = [
+ ":protos_cc",
+ "//tensorflow/core/platform/default/build_config:platformlib",
+ ],
+)
+
+tf_cuda_library(
+ name = "core_cpu",
+ srcs = glob(
+ [
+ "common_runtime/**/*.h",
+ "client/**/*.cc",
+ "common_runtime/**/*.cc",
+ "graph/**/*.h",
+ "graph/**/*.cc",
+ ],
+ exclude = [
+ "**/*test*",
+ "**/*main.cc",
+ "common_runtime/gpu/*.cc",
+ "common_runtime/copy_tensor.cc",
+ "common_runtime/gpu_device_factory.cc",
+ "common_runtime/local_session.cc",
+ "common_runtime/local_session.h",
+ ],
+ ),
+ hdrs = glob(["public/**/*.h"]),
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":copy_tensor",
+ ":framework",
+ ":lib",
+ ":protos_cc",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_cuda_library(
+ name = "framework",
+ srcs = glob(
+ [
+ "framework/**/*.h",
+ "framework/**/*.cc",
+ "util/**/*.h",
+ "util/**/*.cc",
+ ],
+ exclude = [
+ "**/*test*",
+ "**/*main.cc",
+ ],
+ ),
+ hdrs = glob(["public/**/*.h"]),
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lib",
+ ":protos_cc",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_cuda_library(
+ name = "local",
+ srcs = [
+ "common_runtime/local_session.cc",
+ "common_runtime/local_session.h",
+ ],
+ copts = tf_copts(),
+ cuda_deps = [
+ ":cuda",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":core",
+ ":lib",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "copy_tensor",
+ deps = [
+ ":lib",
+ ":protos_cc",
+ ":stream_executor",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cuda_library(
+ name = "gpu_runtime",
+ srcs = glob(
+ [
+ "common_runtime/gpu/**/*.h",
+ "common_runtime/gpu/**/*.cc",
+ ],
+ exclude = [
+ "**/*main.cc",
+ "**/*test.cc",
+ ],
+ ),
+ copts = tf_copts(),
+ cuda_deps = [
+ ":cuda",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":core_cpu",
+ ":lib",
+ ":protos_cc",
+ ":stream_executor",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+# Test support library needed for higher-level tests
+cc_library(
+ name = "testlib",
+ testonly = 1,
+ srcs = [
+ "common_runtime/kernel_benchmark_testlib.cc",
+ "common_runtime/kernel_benchmark_testlib.h",
+ "framework/function_testlib.cc",
+ "framework/function_testlib.h",
+ "framework/tensor_testutil.cc",
+ "framework/tensor_testutil.h",
+ "graph/testlib.cc",
+ "graph/testlib.h",
+ ],
+ copts = tf_copts(),
+ visibility = [
+ ":friends",
+ "//tensorflow:internal",
+ ],
+ deps = [
+ ":core_cpu",
+ ":tensorflow",
+ ":test",
+ "//tensorflow/core/platform/default/build_config:gtest",
+ ],
+)
+
+tf_cuda_library(
+ name = "tensorflow_opensource",
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core",
+ ":gpu_runtime",
+ ":kernels",
+ ":lib",
+ ":local",
+ ],
+)
+
+tf_cuda_library(
+ name = "kernels",
+ srcs = glob(
+ [
+ "kernels/**/*.h",
+ "kernels/**/*.cc",
+ "ops/**/*.h",
+ "ops/**/*.cc",
+ "user_ops/**/*.h",
+ "user_ops/**/*.cc",
+ ],
+ exclude = [
+ "**/*test*",
+ "**/*main.cc",
+ "kernels/**/*.cu.cc",
+ "user_ops/**/*.cu.cc",
+ ],
+ ),
+ copts = tf_copts(),
+ cuda_deps = [
+ ":gpu_kernels",
+ ":cuda",
+ ],
+ linkstatic = 0,
+ visibility = ["//visibility:public"],
+ deps = [
+ "@gemmlowp//:eight_bit_int_gemm",
+ ":core",
+ ":lib",
+ ":protos_cc",
+ ":stream_executor",
+ "//tensorflow/models/embedding:word2vec_kernels",
+ "//tensorflow/models/embedding:word2vec_ops",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_gpu_kernel_library(
+ name = "gpu_kernels",
+ srcs = glob(
+ [
+ "kernels/**/*.h",
+ "kernels/*.cu.cc",
+ "user_ops/**/*.h",
+ "user_ops/*.cu.cc",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//third_party/eigen3",
+ ],
+)
+
+# Test support library needed for all tests
+cc_library(
+ name = "test",
+ testonly = 1,
+ srcs = [
+ "platform/test.cc",
+ ] + tf_additional_test_srcs(),
+ hdrs = [
+ "platform/test.h",
+ "platform/test_benchmark.h",
+ ],
+ copts = tf_copts(),
+ linkopts = ["-lm"],
+ deps = [
+ ":lib",
+ "//tensorflow/core/platform/default/build_config:gtest",
+ ],
+)
+
+# Main program for tests
+cc_library(
+ name = "test_main",
+ testonly = 1,
+ srcs = ["platform/test_main.cc"],
+ copts = tf_copts(),
+ linkopts = ["-lm"],
+ deps = [
+ ":test",
+ "//tensorflow/core/platform/default/build_config:test_main",
+ ],
+)
+
+# TODO(opensource): Make it work externally
+tf_proto_library(
+ name = "protos_all",
+ srcs = glob(["**/*.proto"]),
+ cc_api_version = 2,
+ go_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2,
+ visibility = ["//tensorflow:internal"],
+)
+
+cc_library(
+ name = "protos_cc",
+ deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
+)
+
+# Generates library per group of ops.
+tf_gen_op_libs(
+ op_lib_names = [
+ "array_ops",
+ "attention_ops",
+ "candidate_sampling_ops",
+ "control_flow_ops",
+ "data_flow_ops",
+ "image_ops",
+ "io_ops",
+ "linalg_ops",
+ "logging_ops",
+ "math_ops",
+ "nn_ops",
+ "no_op",
+ "parsing_ops",
+ "random_ops",
+ "sendrecv_ops",
+ "sparse_ops",
+ "state_ops",
+ "string_ops",
+ "summary_ops",
+ "training_ops",
+ ],
+)
+
+# And one for all user ops
+cc_library(
+ name = "user_ops_op_lib",
+ srcs = glob(["user_ops/**/*.cc"]),
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [":framework"],
+ alwayslink = 1,
+)
+
+# Low level library tests
+tf_cc_tests(
+ tests = glob(
+ [
+ "lib/**/*_test.cc",
+ "platform/**/*_test.cc",
+ ],
+ exclude = ["lib/strings/ordered_code_test.cc"],
+ ),
+ deps = [
+ ":lib",
+ ":test_main",
+ ],
+)
+
+cc_test(
+ name = "lib_jpeg_jpeg_mem_unittest",
+ srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
+ data = glob(["lib/jpeg/testdata/*.jpg"]),
+ deps = [
+ ":lib",
+ ":test_main",
+ ],
+)
+
+cc_test(
+ name = "lib_strings_ordered_code_test",
+ srcs = ["lib/strings/ordered_code_test.cc"],
+ copts = ["$(STACK_FRAME_UNLIMITED)"], # Tests initialize large vectors
+ deps = [
+ ":lib",
+ ":test_main",
+ ],
+)
+
+# higher level tests
+tf_cc_tests(
+ linkstatic = tf_kernel_tests_linkstatic(),
+ tests = glob(
+ [
+ "client/**/*_test.cc",
+ "common_runtime/**/*_test.cc",
+ "framework/**/*_test.cc",
+ "graph/**/*_test.cc",
+ "util/**/*_test.cc",
+ ],
+ exclude = [
+ # TODO(opensource): fix
+ "common_runtime/gpu/*_test.cc",
+ # Run by tests below
+ "common_runtime/gpu/gpu_region_allocator_test.cc",
+ "common_runtime/gpu/gpu_bfc_allocator_test.cc",
+ ],
+ ),
+ deps = [
+ ":core",
+ ":kernels",
+ ":lib",
+ ":local",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ ],
+)
+
+# GPU-related tests
+tf_cc_tests(
+ linkstatic = tf_kernel_tests_linkstatic(),
+ tags = tf_cuda_tests_tags(),
+ tests = glob(
+ [
+ "kernels/**/*_test.cc",
+ "user_ops/**/*_test.cc",
+ "common_runtime/gpu/*_test.cc",
+ ],
+ ),
+ deps = [
+ ":kernels",
+ ":local",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ ],
+)
+
+tf_cuda_library(
+ name = "stream_executor",
+ deps = [
+ "//tensorflow/core/platform/default/build_config:stream_executor",
+ ],
+)
+
+cc_library(
+ name = "cuda",
+ visibility = [
+ ":friends",
+ "//tensorflow:internal",
+ ],
+ deps = [
+ "//tensorflow/core/platform/default/build_config:cuda",
+ ],
+)
+
+cc_library(
+ name = "tensorflow",
+ visibility = ["//visibility:public"],
+ deps = [
+ "tensorflow_opensource",
+ "//tensorflow/core/platform/default/build_config:tensorflow_platform_specific",
+ ],
+)
+
+cc_library(
+ name = "core",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core_cpu",
+ ":gpu_runtime",
+ ],
+)
+
+# Android-specific BUILD targets
+load("/tensorflow/tensorflow", "tf_android_core_proto_sources")
+
+# List of protos we want on android
+filegroup(
+ name = "android_proto_srcs",
+ srcs = tf_android_core_proto_sources(),
+ visibility = ["//visibility:public"],
+)
+
+# Core sources. Should eventually become identical to open source
+# sources.
+filegroup(
+ name = "android_srcs",
+ srcs = glob(
+ [
+ "client/**/*.cc",
+ "common_runtime/**/*.h",
+ "common_runtime/**/*.cc",
+ "framework/**/*.h",
+ "framework/**/*.cc",
+ "graph/**/*.h",
+ "graph/**/*.cc",
+ "lib/**/*.h",
+ "lib/**/*.cc",
+ "ops/**/*.cc",
+ "ops/**/*.h",
+ "platform/*.h",
+ "platform/*.cc",
+ "platform/**/*.h",
+ "platform/**/*.cc",
+ "public/**/*.h",
+ "util/**/*.h",
+ "util/**/*.cc",
+ "kernels/ops_util.cc",
+ "kernels/ops_util.h",
+ "kernels/avgpooling_op.h",
+ "kernels/maxpooling_op.h",
+ "kernels/pooling_ops_common.h",
+ "kernels/pooling_ops_common.cc",
+ "kernels/reference_gemm.h",
+ ],
+ exclude = [
+ "**/*test.cc",
+ "**/*testutil*",
+ "**/*testlib*",
+ "**/*main.cc",
+ "lib/jpeg/*.h",
+ "lib/jpeg/*.cc",
+ "lib/png/*.h",
+ "lib/png/*.cc",
+ "util/events_writer.cc",
+ "util/events_writer.h",
+ # Exclude all protobuf/google headers except protobuf_android.h
+ "platform/google/cord_coding.h",
+ "platform/google/dynamic_annotations.h",
+ "platform/google/integral_types.h",
+ "platform/google/mutex.h",
+ "platform/google/protobuf.h",
+ "platform/google/stream_executor_util.h",
+ "platform/google/tracing_impl.h",
+ "platform/google/*.cc",
+ "platform/google/test_benchmark.cc",
+ "platform/google/test_benchmark.h",
+ "kernels/**/*.cu.cc",
+ "user_ops/**/*.cu.cc",
+ "common_runtime/gpu/*.cc",
+ "common_runtime/gpu_device_factory.cc",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+)
+
+# Core kernels we want on Android. Only a subset of kernels to keep
+# base library small.
+filegroup(
+ name = "android_core_ops",
+ srcs = [
+ "//tensorflow/core:kernels/aggregate_ops.cc",
+ "//tensorflow/core:kernels/aggregate_ops.h",
+ "//tensorflow/core:kernels/assign_op.h",
+ "//tensorflow/core:kernels/bias_op.cc",
+ "//tensorflow/core:kernels/bias_op.h",
+ "//tensorflow/core:kernels/cast_op.cc",
+ "//tensorflow/core:kernels/cast_op.h",
+ "//tensorflow/core:kernels/concat_op.cc",
+ "//tensorflow/core:kernels/concat_op.h",
+ "//tensorflow/core:kernels/concat_op_cpu.cc",
+ "//tensorflow/core:kernels/constant_op.cc",
+ "//tensorflow/core:kernels/constant_op.h",
+ "//tensorflow/core:kernels/cwise_ops.h",
+ "//tensorflow/core:kernels/cwise_ops_common.cc",
+ "//tensorflow/core:kernels/cwise_ops_common.h",
+ "//tensorflow/core:kernels/dense_update_ops.cc",
+ "//tensorflow/core:kernels/dense_update_ops.h",
+ "//tensorflow/core:kernels/fill_functor.h",
+ "//tensorflow/core:kernels/gather_op.cc",
+ "//tensorflow/core:kernels/identity_op.cc",
+ "//tensorflow/core:kernels/identity_op.h",
+ "//tensorflow/core:kernels/matmul_op.cc",
+ "//tensorflow/core:kernels/matmul_op.h",
+ "//tensorflow/core:kernels/no_op.cc",
+ "//tensorflow/core:kernels/no_op.h",
+ "//tensorflow/core:kernels/pack_op.cc",
+ "//tensorflow/core:kernels/reference_gemm.h",
+ "//tensorflow/core:kernels/reshape_op.cc",
+ "//tensorflow/core:kernels/reshape_op.h",
+ "//tensorflow/core:kernels/reverse_sequence_op.cc",
+ "//tensorflow/core:kernels/reverse_sequence_op.h",
+ "//tensorflow/core:kernels/sendrecv_ops.cc",
+ "//tensorflow/core:kernels/sendrecv_ops.h",
+ "//tensorflow/core:kernels/sequence_ops.cc",
+ "//tensorflow/core:kernels/shape_ops.cc",
+ "//tensorflow/core:kernels/slice_op.cc",
+ "//tensorflow/core:kernels/slice_op.h",
+ "//tensorflow/core:kernels/softmax_op.cc",
+ "//tensorflow/core:kernels/softmax_op.h",
+ "//tensorflow/core:kernels/split_op.cc",
+ "//tensorflow/core:kernels/split_op.h",
+ "//tensorflow/core:kernels/split_op_cpu.cc",
+ "//tensorflow/core:kernels/unpack_op.cc",
+ "//tensorflow/core:kernels/variable_ops.cc",
+ "//tensorflow/core:kernels/variable_ops.h",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+# Other kernels we may want on Android.
+filegroup(
+ name = "android_extended_ops",
+ srcs = [
+ "//tensorflow/core:kernels/avgpooling_op.cc",
+ "//tensorflow/core:kernels/avgpooling_op.h",
+ "//tensorflow/core:kernels/control_flow_ops.cc",
+ "//tensorflow/core:kernels/control_flow_ops.h",
+ "//tensorflow/core:kernels/conv_2d.h",
+ "//tensorflow/core:kernels/conv_ops.cc",
+ "//tensorflow/core:kernels/cwise_op_add.cc",
+ "//tensorflow/core:kernels/cwise_op_div.cc",
+ "//tensorflow/core:kernels/cwise_op_exp.cc",
+ "//tensorflow/core:kernels/cwise_op_log.cc",
+ "//tensorflow/core:kernels/cwise_op_mul.cc",
+ "//tensorflow/core:kernels/cwise_op_sigmoid.cc",
+ "//tensorflow/core:kernels/cwise_op_sqrt.cc",
+ "//tensorflow/core:kernels/cwise_op_square.cc",
+ "//tensorflow/core:kernels/cwise_op_sub.cc",
+ "//tensorflow/core:kernels/cwise_op_tanh.cc",
+ "//tensorflow/core:kernels/lrn_op.cc",
+ "//tensorflow/core:kernels/maxpooling_op.cc",
+ "//tensorflow/core:kernels/maxpooling_op.h",
+ "//tensorflow/core:kernels/reduction_ops.h",
+ "//tensorflow/core:kernels/reduction_ops_common.h",
+ "//tensorflow/core:kernels/reduction_ops_max.cc",
+ "//tensorflow/core:kernels/reduction_ops_min.cc",
+ "//tensorflow/core:kernels/reduction_ops_sum.cc",
+ "//tensorflow/core:kernels/relu_op.cc",
+ "//tensorflow/core:kernels/relu_op.h",
+ "//tensorflow/core:kernels/softplus_op.cc",
+ "//tensorflow/core:kernels/softplus_op.h",
+ "//tensorflow/core:kernels/transpose_op.cc",
+ "//tensorflow/core:kernels/transpose_op.h",
+ "//tensorflow/core:kernels/transpose_op_functor.h",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+# Test data
+filegroup(
+ name = "image_testdata",
+ srcs = [
+ # PNG data
+ "lib/png/testdata/lena_gray.png",
+ "lib/png/testdata/lena_rgba.png",
+ # JPEG data
+ "lib/jpeg/testdata/jpeg_merge_test1.jpg",
+ "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg",
+ # Corrupted JPEG files for tests
+ "lib/jpeg/testdata/bad_huffman.jpg",
+ "lib/jpeg/testdata/corrupt.jpg",
+ # -- hand-edited variant: stops at line 0
+ "lib/jpeg/testdata/corrupt34_2.jpg",
+ # -- hand-edited variant: stops at line 4
+ "lib/jpeg/testdata/corrupt34_3.jpg",
+ # -- hand-edited variant: stops after a restart marker
+ "lib/jpeg/testdata/corrupt34_4.jpg",
+ ],
+)
+
+# For portable_proto_library
+
+# Native library support for Android applications.
+# Should be built to target Android with flag --copt=-mfpu=neon.
+cc_library(
+ name = "android_tensorflow_lib",
+ srcs = [
+ "//tensorflow/core:android_core_ops",
+ "//tensorflow/core:android_extended_ops",
+ "//tensorflow/core:android_srcs",
+ ],
+ copts = [
+ "-mfpu=neon",
+ "-std=c++11",
+ ],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "@re2//:re2",
+ ":protos_cc",
+ "//third_party/eigen3",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
new file mode 100644
index 0000000000..59cf0ed8f9
--- /dev/null
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -0,0 +1,370 @@
+#include "tensorflow/core/public/tensor_c_api.h"
+
+#include <memory>
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+// The implementation below is at the top level instead of the
+// brain namespace because we are defining 'extern "C"' functions.
+using tensorflow::error::Code;
+using tensorflow::errors::InvalidArgument;
+using tensorflow::gtl::ArraySlice;
+using tensorflow::AllocationDescription;
+using tensorflow::Status;
+using tensorflow::DataType;
+using tensorflow::Env;
+using tensorflow::GraphDef;
+using tensorflow::NewSession;
+using tensorflow::Session;
+using tensorflow::Tensor;
+using tensorflow::TensorBuffer;
+using tensorflow::SessionOptions;
+using tensorflow::TensorShape;
+
+extern "C" {
+
+// --------------------------------------------------------------------------
+struct TF_Status {
+ Status status;
+};
+
+TF_Status* TF_NewStatus() { return new TF_Status; }
+
+void TF_DeleteStatus(TF_Status* s) { delete s; }
+
+void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
+ s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
+}
+
+TF_Code TF_GetCode(const TF_Status* s) {
+ return static_cast<TF_Code>(s->status.code());
+}
+
+const char* TF_Message(const TF_Status* s) {
+ return s->status.error_message().c_str();
+}
+
+// --------------------------------------------------------------------------
+
+namespace {
+class TF_ManagedBuffer : public TensorBuffer {
+ public:
+ void* data_;
+ size_t len_;
+ void (*deallocator_)(void* data, size_t len, void* arg);
+ void* deallocator_arg_;
+
+ ~TF_ManagedBuffer() override {
+ (*deallocator_)(data_, len_, deallocator_arg_);
+ }
+
+ void* data() const override { return data_; }
+ size_t size() const override { return len_; }
+ TensorBuffer* root_buffer() override { return this; }
+ void FillAllocationDescription(AllocationDescription* proto) const override {
+ tensorflow::int64 rb = size();
+ proto->set_requested_bytes(rb);
+ proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
+ }
+};
+
+void deallocate_realigned_buffer(void* data, size_t len, void* arg) {
+ tensorflow::cpu_allocator()->DeallocateRaw(data);
+}
+} // namespace
+
+struct TF_Tensor {
+ TF_DataType dtype;
+ TensorShape shape;
+ TensorBuffer* buffer;
+};
+
+TF_Tensor* TF_NewTensor(TF_DataType dtype, tensorflow::int64* dims,
+ int num_dims, void* data, size_t len,
+ void (*deallocator)(void* data, size_t len, void* arg),
+ void* deallocator_arg) {
+ std::vector<tensorflow::int64> dimvec(num_dims);
+ for (int i = 0; i < num_dims; i++) {
+ dimvec[i] = dims[i];
+ }
+
+ TF_ManagedBuffer* buf = new TF_ManagedBuffer;
+ buf->len_ = len;
+ if (reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
+ // Copy the data into a buffer that satisfies Eigen's alignment
+ // requirements.
+ buf->data_ =
+ tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
+ std::memcpy(buf->data_, data, len);
+ buf->deallocator_ = deallocate_realigned_buffer;
+ buf->deallocator_arg_ = nullptr;
+ // Free the original buffer.
+ deallocator(data, len, deallocator_arg);
+ } else {
+ buf->data_ = data;
+ buf->deallocator_ = deallocator;
+ buf->deallocator_arg_ = deallocator_arg;
+ }
+ return new TF_Tensor{dtype, TensorShape(dimvec), buf};
+}
+
+void TF_DeleteTensor(TF_Tensor* t) {
+ t->buffer->Unref();
+ delete t;
+}
+
+TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
+int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
+tensorflow::int64 TF_Dim(const TF_Tensor* t, int dim_index) {
+ return t->shape.dim_size(dim_index);
+}
+size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
+void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
+
+// --------------------------------------------------------------------------
+struct TF_SessionOptions {
+ SessionOptions options;
+};
+TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
+void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
+
+void TF_SetTarget(TF_SessionOptions* options, const char* target) {
+ options->options.target = target;
+}
+
+void TF_SetConfig(TF_SessionOptions* options, const char* config,
+ size_t config_len, TF_Status* status) {
+ if (!options->options.config.ParseFromArray(config, config_len)) {
+ status->status =
+ tensorflow::errors::InvalidArgument("Unparseable ConfigProto");
+ }
+}
+
+// --------------------------------------------------------------------------
+struct TF_Session {
+ Session* session;
+};
+
+TF_Session* TF_NewSession(const TF_SessionOptions* opt, TF_Status* status) {
+ Session* session;
+ status->status = NewSession(opt->options, &session);
+ if (status->status.ok()) {
+ return new TF_Session({session});
+ } else {
+ DCHECK_EQ(nullptr, session);
+ return NULL;
+ }
+}
+
+void TF_CloseSession(TF_Session* s, TF_Status* status) {
+ status->status = s->session->Close();
+}
+
+void TF_DeleteSession(TF_Session* s, TF_Status* status) {
+ status->status = Status::OK();
+ delete s->session;
+ delete s;
+}
+
+void TF_ExtendGraph(TF_Session* s, const void* proto, size_t proto_len,
+ TF_Status* status) {
+ GraphDef g;
+ if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument("Invalid GraphDef");
+ return;
+ }
+ status->status = s->session->Extend(g);
+}
+
+static void DeleteArray(void* data, size_t size, void* arg) {
+ DCHECK_EQ(data, arg);
+ delete[] reinterpret_cast<char*>(arg);
+}
+
+} // end extern "C"
+
+namespace tensorflow {
+
+// Non-static for testing.
+bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status) {
+ const tensorflow::int64 num_elements = src->shape.num_elements();
+ const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
+ const size_t src_size = TF_TensorByteSize(src);
+ if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
+ num_elements) {
+ status->status = InvalidArgument(
+ "Malformed TF_STRING tensor; too short to hold number of elements");
+ return false;
+ }
+ const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
+ const char* limit = input + src_size;
+
+ *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
+ auto dstarray = dst->flat<tensorflow::string>();
+ for (tensorflow::int64 i = 0; i < num_elements; i++) {
+ tensorflow::uint64 offset =
+ reinterpret_cast<const tensorflow::uint64*>(input)[i];
+ tensorflow::uint64 len;
+ const char* p;
+ if (static_cast<ptrdiff_t>(offset) >= (limit - data_start) ||
+ !(p = tensorflow::core::GetVarint64Ptr(data_start + offset, limit,
+ &len)) ||
+ (static_cast<ptrdiff_t>(len) > (limit - p))) {
+ status->status = InvalidArgument("Malformed TF_STRING tensor; element ",
+ i, " out of range");
+ return false;
+ }
+ dstarray(i).assign(p, len);
+ }
+ return true;
+}
+
+// Non-static for testing.
+TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src) {
+ // Compute bytes needed for encoding.
+ size_t size = 0;
+ const auto& srcarray = src.flat<tensorflow::string>();
+ for (int i = 0; i < srcarray.size(); i++) {
+ const tensorflow::string& s = srcarray(i);
+ // uint64 starting_offset, varint64 length, string contents
+ size += sizeof(tensorflow::uint64) +
+ tensorflow::core::VarintLength(s.size()) + s.size();
+ }
+
+ // Encode all strings.
+ char* base = new char[size];
+ char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
+ char* dst = data_start; // Where next string is encoded.
+ tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
+ for (int i = 0; i < srcarray.size(); i++) {
+ const tensorflow::string& s = srcarray(i);
+ *offsets = (dst - data_start);
+ offsets++;
+ dst = tensorflow::core::EncodeVarint64(dst, s.size());
+ memcpy(dst, s.data(), s.size());
+ dst += s.size();
+ }
+ CHECK_EQ(dst, base + size);
+
+ auto dims = src.shape().dim_sizes();
+ std::vector<tensorflow::int64> dimvec(dims.size());
+ for (size_t i = 0; i < dims.size(); i++) {
+ dimvec[i] = dims[i];
+ }
+ return TF_NewTensor(TF_STRING, dimvec.data(), dimvec.size(), base, size,
+ DeleteArray, base);
+}
+
+class TensorCApi {
+ public:
+ static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; }
+ static Tensor MakeTensor(TF_DataType type, const TensorShape& shape,
+ TensorBuffer* buf) {
+ return Tensor(static_cast<DataType>(type), shape, buf);
+ }
+};
+
+// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
+// result in a zero-sized tensor.
+static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
+ static char empty;
+ tensorflow::int64 nelems = 1;
+ std::vector<tensorflow::int64> dims;
+ for (int i = 0; i < shape.dims(); ++i) {
+ dims.push_back(shape.dim_size(i));
+ nelems *= shape.dim_size(i);
+ }
+ CHECK_EQ(nelems, 0);
+ return TF_NewTensor(dtype, dims.data(), shape.dims(),
+ reinterpret_cast<void*>(&empty), 0,
+ [](void*, size_t, void*) {}, nullptr);
+}
+
+} // namespace tensorflow
+
+extern "C" {
+
+void TF_Run(TF_Session* s,
+ // Input tensors
+ const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
+ // Output tensors
+ const char** c_output_tensor_names, TF_Tensor** c_outputs,
+ int noutputs,
+ // Target nodes
+ const char** c_target_node_names, int ntargets, TF_Status* status) {
+ status->status = Status::OK();
+ for (int i = 0; i < noutputs; i++) {
+ c_outputs[i] = NULL;
+ }
+
+ // Initialize inputs.
+ std::vector<std::pair<tensorflow::string, Tensor>> inputs(ninputs);
+ bool ok = true;
+ for (int i = 0; i < ninputs; i++) {
+ TF_Tensor* src = c_inputs[i];
+ if (ok) {
+ inputs[i].first = c_input_names[i];
+ if (c_inputs[i]->dtype != TF_STRING) {
+ inputs[i].second = tensorflow::TensorCApi::MakeTensor(
+ src->dtype, src->shape, src->buffer);
+ } else {
+ // TF_STRING tensors require copying since Tensor class expects
+ // a sequence of string objects.
+ ok =
+ tensorflow::TF_Tensor_DecodeStrings(src, &inputs[i].second, status);
+ // Must keep looping through all inputs even if there is an error
+ // so that TF_DeleteTensor() is called unconditionally on all inputs.
+ }
+ }
+ TF_DeleteTensor(src);
+ }
+ if (!ok) {
+ return;
+ }
+
+ std::vector<tensorflow::string> output_tensor_names(noutputs);
+ std::vector<Tensor> outputs(noutputs);
+ std::vector<tensorflow::string> target_node_names(ntargets);
+ for (int i = 0; i < noutputs; i++) {
+ output_tensor_names[i] = c_output_tensor_names[i];
+ }
+ for (int i = 0; i < ntargets; i++) {
+ target_node_names[i] = c_target_node_names[i];
+ }
+ Status result =
+ s->session->Run(inputs, output_tensor_names, target_node_names, &outputs);
+ if (!result.ok()) {
+ status->status = result;
+ return;
+ }
+
+ // Store results in c_outputs[]
+ for (int i = 0; i < noutputs; i++) {
+ const Tensor& src = outputs[i];
+ if (!src.IsInitialized()) {
+ c_outputs[i] = tensorflow::EmptyTensor(
+ static_cast<TF_DataType>(src.dtype()), src.shape());
+ continue;
+ }
+ if (src.dtype() != tensorflow::DT_STRING) {
+ // Share the underlying buffer.
+ TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src);
+ buf->Ref();
+ c_outputs[i] = new TF_Tensor{static_cast<TF_DataType>(src.dtype()),
+ src.shape(), buf};
+ } else {
+ c_outputs[i] = tensorflow::TF_Tensor_EncodeStrings(src);
+ }
+ }
+}
+
+} // end extern "C"
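
A note on ownership in the API above: TF_NewTensor copies the caller's buffer only when it is not sufficiently aligned for Eigen (invoking the caller's deallocator immediately in that case), TF_Run deletes every input tensor it is given, and every output tensor it returns must be freed by the caller with TF_DeleteTensor. The following is a minimal lifecycle sketch against the functions defined in this file; the helper names are illustrative, and running a real graph would additionally require serializing a GraphDef and passing it to TF_ExtendGraph.

#include "tensorflow/core/public/tensor_c_api.h"
#include "tensorflow/core/public/tensor.h"

// Frees a buffer obtained from the TensorFlow CPU allocator.
static void FreeCpuBuffer(void* data, size_t, void*) {
  tensorflow::cpu_allocator()->DeallocateRaw(data);
}

void CApiLifecycleSketch() {
  TF_Status* status = TF_NewStatus();

  // A 2x3 float tensor backed by an Eigen-aligned buffer, so TF_NewTensor
  // takes ownership of it without copying.
  tensorflow::int64 dims[] = {2, 3};
  void* data = tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES,
                                                        6 * sizeof(float));
  TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, data, 6 * sizeof(float),
                              &FreeCpuBuffer, nullptr);
  TF_DeleteTensor(t);  // Releases the buffer via FreeCpuBuffer.

  // Session lifecycle with default options.
  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(opts, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_CloseSession(session, status);
    TF_DeleteSession(session, status);
  }
  TF_DeleteSessionOptions(opts);
  TF_DeleteStatus(status);
}
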
diff --git a/tensorflow/core/client/tensor_c_api_test.cc b/tensorflow/core/client/tensor_c_api_test.cc
new file mode 100644
index 0000000000..4afdd0c0df
--- /dev/null
+++ b/tensorflow/core/client/tensor_c_api_test.cc
@@ -0,0 +1,94 @@
+#include "tensorflow/core/public/tensor_c_api.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/tensor.h"
+
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+
+namespace tensorflow {
+bool TF_Tensor_DecodeStrings(TF_Tensor* src, Tensor* dst, TF_Status* status);
+TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src);
+} // namespace tensorflow
+
+TEST(CApi, Status) {
+ TF_Status* s = TF_NewStatus();
+ EXPECT_EQ(TF_OK, TF_GetCode(s));
+ EXPECT_EQ(tensorflow::string(), TF_Message(s));
+ TF_SetStatus(s, TF_CANCELLED, "cancel");
+ EXPECT_EQ(TF_CANCELLED, TF_GetCode(s));
+ EXPECT_EQ(tensorflow::string("cancel"), TF_Message(s));
+ TF_DeleteStatus(s);
+}
+
+static void Deallocator(void* data, size_t, void* arg) {
+ tensorflow::cpu_allocator()->DeallocateRaw(data);
+ *reinterpret_cast<bool*>(arg) = true;
+}
+
+TEST(CApi, Tensor) {
+ float* values =
+ reinterpret_cast<float*>(tensorflow::cpu_allocator()->AllocateRaw(
+ EIGEN_MAX_ALIGN_BYTES, 6 * sizeof(float)));
+ tensorflow::int64 dims[] = {2, 3};
+ bool deallocator_called = false;
+  TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 2, values, 6 * sizeof(float),
+                              &Deallocator, &deallocator_called);
+ EXPECT_FALSE(deallocator_called);
+ EXPECT_EQ(TF_FLOAT, TF_TensorType(t));
+ EXPECT_EQ(2, TF_NumDims(t));
+ EXPECT_EQ(dims[0], TF_Dim(t, 0));
+ EXPECT_EQ(dims[1], TF_Dim(t, 1));
+  EXPECT_EQ(6 * sizeof(float), TF_TensorByteSize(t));
+ EXPECT_EQ(static_cast<void*>(values), TF_TensorData(t));
+ TF_DeleteTensor(t);
+ EXPECT_TRUE(deallocator_called);
+}
+
+static void TestEncodeDecode(int line,
+ const std::vector<tensorflow::string>& data) {
+ const tensorflow::int64 n = data.size();
+ for (std::vector<tensorflow::int64> dims :
+ std::vector<std::vector<tensorflow::int64>>{
+ {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
+ // Create C++ Tensor
+ Tensor src(tensorflow::DT_STRING, TensorShape(dims));
+ for (tensorflow::int64 i = 0; i < src.NumElements(); i++) {
+ src.flat<tensorflow::string>()(i) = data[i];
+ }
+ TF_Tensor* dst = TF_Tensor_EncodeStrings(src);
+
+ // Convert back to a C++ Tensor and ensure we get expected output.
+ TF_Status* status = TF_NewStatus();
+ Tensor output;
+ ASSERT_TRUE(TF_Tensor_DecodeStrings(dst, &output, status)) << line;
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << line;
+ ASSERT_EQ(src.NumElements(), output.NumElements()) << line;
+ for (tensorflow::int64 i = 0; i < src.NumElements(); i++) {
+ ASSERT_EQ(data[i], output.flat<tensorflow::string>()(i)) << line;
+ }
+
+ TF_DeleteStatus(status);
+ TF_DeleteTensor(dst);
+ }
+}
+
+TEST(CApi, TensorEncodeDecodeStrings) {
+ TestEncodeDecode(__LINE__, {});
+ TestEncodeDecode(__LINE__, {"hello"});
+ TestEncodeDecode(__LINE__,
+ {"the", "quick", "brown", "fox", "jumped", "over"});
+
+ tensorflow::string big(1000, 'a');
+ TestEncodeDecode(__LINE__, {"small", big, "small2"});
+}
+
+TEST(CApi, SessionOptions) {
+ TF_SessionOptions* opt = TF_NewSessionOptions();
+ TF_DeleteSessionOptions(opt);
+}
+
+// TODO(jeff,sanjay): Session tests
+// . Create and delete
+// . Extend graph
+// . Run
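
The TODO above lists the missing Session coverage; a sketch of the "create and delete" case could look like the following, assuming a local session implementation is linked into the test binary so that TF_NewSession succeeds with default options. The test name is illustrative.

TEST(CApi, SessionCreateAndDelete) {
  TF_Status* s = TF_NewStatus();
  TF_SessionOptions* opt = TF_NewSessionOptions();
  TF_Session* session = TF_NewSession(opt, s);
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  TF_CloseSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s));
  TF_DeleteSession(session, s);
  EXPECT_EQ(TF_OK, TF_GetCode(s));
  TF_DeleteSessionOptions(opt);
  TF_DeleteStatus(s);
}
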
diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc
new file mode 100644
index 0000000000..2e3e7b6597
--- /dev/null
+++ b/tensorflow/core/common_runtime/device.cc
@@ -0,0 +1,37 @@
+#include "tensorflow/core/common_runtime/device.h"
+
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+Device::Device(Env* env, const DeviceAttributes& device_attributes,
+ Allocator* device_allocator)
+ : DeviceBase(env), device_attributes_(device_attributes) {
+ CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
+ << "Invalid device name: " << name();
+ rmgr_ = new ResourceMgr(parsed_name_.job);
+}
+
+Device::~Device() { delete rmgr_; }
+
+// static
+DeviceAttributes Device::BuildDeviceAttributes(
+ const string& name, DeviceType device, Bytes memory_limit,
+ BusAdjacency bus_adjacency, const string& physical_device_desc) {
+ DeviceAttributes da;
+ da.set_name(name);
+ do {
+ da.set_incarnation(random::New64());
+ } while (da.incarnation() == 0); // This proto field must not be zero
+ da.set_device_type(device.type());
+ da.set_memory_limit(memory_limit.value());
+ da.set_bus_adjacency(bus_adjacency);
+ da.set_physical_device_desc(physical_device_desc);
+ return da;
+}
+
+} // namespace tensorflow
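
For context, a device implementation typically calls BuildDeviceAttributes while constructing itself. The sketch below is illustrative only: the device name, memory limit, bus adjacency, and description are placeholder values, not taken from any real factory in this change.

#include "tensorflow/core/common_runtime/device.h"

namespace tensorflow {

// Placeholder values throughout; a real factory derives the name from its
// name_prefix argument and the memory limit from the machine.
DeviceAttributes ExampleCpuAttributes() {
  return Device::BuildDeviceAttributes(
      "/job:localhost/replica:0/task:0/cpu:0", DeviceType(DEVICE_CPU),
      Bytes(256 << 20), BUS_ANY, "Example physical CPU description");
}

}  // namespace tensorflow
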
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
new file mode 100644
index 0000000000..ff3404fea4
--- /dev/null
+++ b/tensorflow/core/common_runtime/device.h
@@ -0,0 +1,128 @@
+// A Device is something that can perform computations as part of a
+// model. Devices can be local (runs computation on this machine), or
+// remote (contacts a device local to another machine using an RPC to
+// do the work). Devices are registered in a DeviceSet, which is also
+// responsible for the Device <-> id mapping.
+//
+// Device names
+// * Every Device should have a unique name with the format:
+// /job:___/replica:___/task:___/(gpu|cpu):___
+// An example name would be "/job:train/replica:0/task:3/gpu:2".
+// * Task numbers are within the specified replica, so there are as
+// many "task zeros" as replicas.
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+class Device : public DeviceBase {
+ public:
+ Device(Env* env, const DeviceAttributes& device_attributes,
+ Allocator* device_allocator);
+ ~Device() override;
+
+ // Full name of this device (see top comment).
+ const string& name() const { return device_attributes_.name(); }
+
+ // Parsed name of this device
+ const DeviceNameUtils::ParsedName parsed_name() const { return parsed_name_; }
+
+ // Describes what kind of device this is. This is intended to be
+ // human-readable and not computer-parsed, except that two devices
+ // with the same device_type() are expected to perform similarly
+ // (both from a computation and communication perspective).
+ const string& device_type() const { return device_attributes_.device_type(); }
+
+ // Returns an aggregation of device attributes.
+ const DeviceAttributes& attributes() const override {
+ return device_attributes_;
+ }
+
+ // Performs the actual compute function.
+ //
+ // Subclasses may override this function if they wish to perform
+ // some initialization before each compute.
+ virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ op_kernel->Compute(context);
+ }
+
+ // Asynchronous kernel's compute.
+ virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) {
+ op_kernel->ComputeAsync(context, done);
+ }
+
+ // Blocks until all operations queued on the device at the time of
+ // the call have completed. Returns any error pending on the device
+ // at completion.
+ virtual Status Sync() = 0;
+
+ // Fill in the context map for the graph. Default behavior is to do
+ // nothing.
+ //
+ // The caller takes ownership over the DeviceContext objects given
+ // by the device.
+ virtual Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) {
+ return Status::OK();
+ }
+
+ // Returns the op segment of this device. The caller can reuse op
+ // kernels registered for the same session running on this device.
+ OpSegment* op_segment() { return &op_seg_; }
+
+ // Returns the resource manager associated w/ this device.
+ ResourceMgr* resource_manager() { return rmgr_; }
+
+ // Summarizes the status of this Device, for debugging.
+ string DebugString() const { return device_attributes_.DebugString(); }
+
+ // Assembles the parameter components into a complete DeviceAttributes value.
+ static DeviceAttributes BuildDeviceAttributes(
+ const string& name, DeviceType device, Bytes memory_limit,
+ BusAdjacency bus_adjacency, const string& physical_device_desc);
+
+ static DeviceAttributes BuildDeviceAttributes(const string& name,
+ DeviceType device,
+ Bytes memory_limit,
+ BusAdjacency bus_adjacency) {
+ // Pass in an empty string as physical device name.
+ return BuildDeviceAttributes(name, device, memory_limit, bus_adjacency, "");
+ }
+
+ private:
+ const DeviceAttributes device_attributes_;
+ DeviceNameUtils::ParsedName parsed_name_;
+
+ // op_seg_ maps session handle and op name to OpKernel objects.
+ OpSegment op_seg_;
+
+ // Resources associated w/ this device. E.g., shared variables, etc.
+ ResourceMgr* rmgr_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Device);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
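
The naming scheme described in the header comment above can be exercised directly with DeviceNameUtils, the same helper the Device constructor uses to validate its own name. A small sketch, with the parsed field values noted as comments rather than asserted, since the exact type normalization is an implementation detail:

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

void ParseDeviceNameSketch() {
  tensorflow::DeviceNameUtils::ParsedName parsed;
  CHECK(tensorflow::DeviceNameUtils::ParseFullName(
      "/job:train/replica:0/task:3/gpu:2", &parsed));
  // parsed.job is "train", parsed.replica is 0, parsed.task is 3, and the
  // type/id fields identify GPU device 2 on that task.
}
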
diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc
new file mode 100644
index 0000000000..7d391bde1d
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_factory.cc
@@ -0,0 +1,106 @@
+#include "tensorflow/core/common_runtime/device_factory.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+namespace {
+
+static mutex* get_device_factory_lock() {
+ static mutex device_factory_lock;
+ return &device_factory_lock;
+}
+
+struct FactoryItem {
+ std::unique_ptr<DeviceFactory> factory;
+ int priority;
+};
+
+std::unordered_map<string, FactoryItem>& device_factories() {
+ static std::unordered_map<string, FactoryItem>* factories =
+ new std::unordered_map<string, FactoryItem>;
+ return *factories;
+}
+} // namespace
+
+void DeviceFactory::Register(const string& device_type, DeviceFactory* factory,
+ int priority) {
+ mutex_lock l(*get_device_factory_lock());
+ std::unique_ptr<DeviceFactory> factory_ptr(factory);
+ std::unordered_map<string, FactoryItem>& factories = device_factories();
+ auto iter = factories.find(device_type);
+ if (iter == factories.end()) {
+ factories[device_type] = {std::move(factory_ptr), priority};
+ } else {
+ if (iter->second.priority < priority) {
+ iter->second = {std::move(factory_ptr), priority};
+ } else if (iter->second.priority == priority) {
+ LOG(FATAL) << "Duplicate registration of device factory for type "
+ << device_type << " with the same priority " << priority;
+ }
+ }
+}
+
+DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
+ mutex_lock l(*get_device_factory_lock()); // could use reader lock
+ auto it = device_factories().find(device_type);
+ if (it == device_factories().end()) {
+ return nullptr;
+ }
+ return it->second.factory.get();
+}
+
+void DeviceFactory::AddDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ // CPU first.
+ auto cpu_factory = GetFactory("CPU");
+ if (!cpu_factory) {
+ LOG(FATAL)
+ << "CPU Factory not registered. Did you link in threadpool_device?";
+ }
+ size_t init_size = devices->size();
+ cpu_factory->CreateDevices(options, name_prefix, devices);
+ if (devices->size() == init_size) {
+ LOG(FATAL) << "No CPU devices are available in this process";
+ }
+
+ // Then GPU.
+ auto gpu_factory = GetFactory("GPU");
+ if (gpu_factory) {
+ gpu_factory->CreateDevices(options, name_prefix, devices);
+ }
+
+ // Then the rest.
+ mutex_lock l(*get_device_factory_lock());
+ for (auto& p : device_factories()) {
+ auto factory = p.second.factory.get();
+ if (factory != cpu_factory && factory != gpu_factory) {
+ factory->CreateDevices(options, name_prefix, devices);
+ }
+ }
+}
+
+Device* DeviceFactory::NewDevice(const string& type,
+ const SessionOptions& options,
+ const string& name_prefix) {
+ auto device_factory = GetFactory(type);
+ if (!device_factory) {
+ return nullptr;
+ }
+ SessionOptions opt = options;
+ (*opt.config.mutable_device_count())[type] = 1;
+ std::vector<Device*> devices;
+ device_factory->CreateDevices(opt, name_prefix, &devices);
+ CHECK_EQ(devices.size(), 1);
+ return devices[0];
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
new file mode 100644
index 0000000000..57b625b3e5
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -0,0 +1,69 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Device;
+struct SessionOptions;
+
+class DeviceFactory {
+ public:
+ virtual ~DeviceFactory() {}
+ static void Register(const string& device_type, DeviceFactory* factory,
+ int priority);
+ static DeviceFactory* GetFactory(const string& device_type);
+
+ // Append to "*devices" all suitable devices, respecting
+ // any device type specific properties/counts listed in "options".
+ //
+ // CPU devices are added first.
+ static void AddDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices);
+
+ // Helper for tests. Create a single device of type "type". The
+ // returned device is always numbered zero, so if creating multiple
+ // devices of the same type, supply distinct name_prefix arguments.
+ static Device* NewDevice(const string& type, const SessionOptions& options,
+ const string& name_prefix);
+
+ // Most clients should call AddDevices() instead.
+ virtual void CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) = 0;
+};
+
+namespace dfactory {
+
+template <class Factory>
+class Registrar {
+ public:
+ // Multiple registrations for the same device type with different priorities
+ // are allowed. The registration with the highest priority will be used.
+ explicit Registrar(const string& device_type, int priority = 0) {
+ DeviceFactory::Register(device_type, new Factory(), priority);
+ }
+};
+
+} // namespace dfactory
+
+#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
+ INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
+ __COUNTER__, ##__VA_ARGS__)
+
+#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
+ ctr, ...) \
+ static ::tensorflow::dfactory::Registrar<device_factory> \
+ INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, \
+ ##__VA_ARGS__)
+
+// __COUNTER__ must go through another macro to be properly expanded
+#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
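
A hypothetical device type would plug into the registration macro above as sketched below. MyDeviceFactory, the "MY_DEVICE" type string, and the priority are all illustrative; real factories, such as the CPU factory that AddDevices() expects to find, follow the same pattern.

#include "tensorflow/core/common_runtime/device_factory.h"

namespace tensorflow {

class MyDeviceFactory : public DeviceFactory {
 public:
  void CreateDevices(const SessionOptions& options, const string& name_prefix,
                     std::vector<Device*>* devices) override {
    // A real factory appends one Device* per device it manages, honoring any
    // per-type count requested in the session options.
  }
};

// Priority 50 wins over a lower-priority registration for the same type;
// registering twice with equal priority is a fatal error (see Register()).
REGISTER_LOCAL_DEVICE_FACTORY("MY_DEVICE", MyDeviceFactory, 50);

}  // namespace tensorflow
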
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
new file mode 100644
index 0000000000..4fa13f6b4b
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -0,0 +1,90 @@
+#include "tensorflow/core/common_runtime/device_mgr.h"
+
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+DeviceMgr::DeviceMgr(const std::vector<Device*>& devices) {
+ for (Device* d : devices) {
+ devices_.push_back(d);
+
+ // Register under both the full name and the local name.
+ device_map_[d->name()] = d;
+ device_map_[DeviceNameUtils::LocalName(d->name())] = d;
+ device_type_counts_[d->device_type()]++;
+ }
+}
+
+DeviceMgr::~DeviceMgr() {
+ for (auto p : devices_) delete p;
+}
+
+void DeviceMgr::ListDeviceAttributes(
+ std::vector<DeviceAttributes>* devices) const {
+ devices->reserve(devices_.size());
+ for (Device* dev : devices_) {
+ devices->emplace_back(dev->attributes());
+ }
+}
+
+std::vector<Device*> DeviceMgr::ListDevices() const {
+ return std::vector<Device*>(devices_.begin(), devices_.end());
+}
+
+string DeviceMgr::DebugString() const {
+ string out;
+ for (Device* dev : devices_) {
+ strings::StrAppend(&out, dev->name(), "\n");
+ }
+ return out;
+}
+
+string DeviceMgr::DeviceMappingString() const {
+ string out;
+ for (Device* dev : devices_) {
+ if (!dev->attributes().physical_device_desc().empty()) {
+ strings::StrAppend(&out, dev->name(), " -> ",
+ dev->attributes().physical_device_desc(), "\n");
+ }
+ }
+ return out;
+}
+
+Status DeviceMgr::LookupDevice(const string& name, Device** device) const {
+ Status s;
+ auto iter = device_map_.find(name);
+ if (iter == device_map_.end()) {
+ return errors::InvalidArgument(name, " unknown device.");
+ }
+ *device = iter->second;
+ return Status::OK();
+}
+
+void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
+ Status s;
+ for (Device* dev : devices_) {
+ if (containers.empty()) {
+ s.Update(dev->resource_manager()->Cleanup(
+ dev->resource_manager()->default_container()));
+ } else {
+ for (const string& c : containers) {
+ s.Update(dev->resource_manager()->Cleanup(c));
+ }
+ }
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ }
+ }
+}
+
+int DeviceMgr::NumDeviceType(const string& type) const {
+ auto iter = device_type_counts_.find(type);
+ if (iter != device_type_counts_.end()) return iter->second;
+ return 0;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
new file mode 100644
index 0000000000..c57d0222aa
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -0,0 +1,55 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class DeviceAttributes;
+
+class DeviceMgr {
+ public:
+ // TODO(zhifengc): Other initialization information.
+ explicit DeviceMgr(const std::vector<Device*>& devices);
+ ~DeviceMgr();
+
+ // Returns attributes of all devices.
+ void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
+
+ std::vector<Device*> ListDevices() const;
+
+ // Returns a string listing all devices.
+ string DebugString() const;
+
+  // Returns a string describing the mapping from each device name to its
+  // physical device description.
+ string DeviceMappingString() const;
+
+ // Assigns *device with pointer to Device of the given name.
+ // Accepts either a full device name, or just the replica-local suffix.
+ Status LookupDevice(const string& name, Device** device) const;
+
+  // Clears the given containers on all devices if 'containers' is
+  // non-empty. Otherwise, clears the default container on all devices.
+ void ClearContainers(gtl::ArraySlice<string> containers) const;
+
+ int NumDeviceType(const string& type) const;
+
+ private:
+ typedef gtl::InlinedVector<Device*, 8> DeviceVec;
+ DeviceVec devices_;
+ std::unordered_map<string, Device*> device_map_;
+ std::unordered_map<string, int> device_type_counts_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
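
Typical wiring of the pieces above: devices are created through the registered factories, handed to a DeviceMgr (which takes ownership and deletes them in its destructor), and then looked up by name. A sketch with an illustrative name prefix; per the comment on LookupDevice, a replica-local suffix such as "cpu:0" would also be accepted.

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

Status DeviceMgrSketch() {
  SessionOptions options;
  std::vector<Device*> devices;
  DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
                            &devices);
  DeviceMgr device_mgr(devices);  // Takes ownership of the devices.

  Device* cpu = nullptr;
  Status s = device_mgr.LookupDevice("/job:localhost/replica:0/task:0/cpu:0",
                                     &cpu);
  if (!s.ok()) return s;
  LOG(INFO) << device_mgr.DeviceMappingString();
  return Status::OK();
}

}  // namespace tensorflow
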
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc
new file mode 100644
index 0000000000..3b0465d9a6
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_set.cc
@@ -0,0 +1,68 @@
+#include "tensorflow/core/common_runtime/device_set.h"
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+DeviceSet::DeviceSet() {}
+
+DeviceSet::~DeviceSet() {}
+
+void DeviceSet::AddDevice(Device* device) {
+ devices_.push_back(device);
+ device_by_name_.insert({device->name(), device});
+}
+
+void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
+ std::vector<Device*>* devices) const {
+  // TODO(jeff): If we are going to repeatedly look up the set of devices
+  // for the same spec, maybe we should have a cache of some sort.
+ devices->clear();
+ for (Device* d : devices_) {
+ if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) {
+ devices->push_back(d);
+ }
+ }
+}
+
+Device* DeviceSet::FindDeviceByName(const string& name) const {
+ return gtl::FindPtrOrNull(device_by_name_, name);
+}
+
+// Higher result implies lower priority.
+static int Order(const DeviceType& d) {
+ if (StringPiece(d.type()) == DEVICE_CPU) {
+ return 3;
+ } else if (StringPiece(d.type()) == DEVICE_GPU) {
+ return 2;
+ } else {
+ return 1;
+ }
+}
+
+static bool ByPriority(const DeviceType& a, const DeviceType& b) {
+ // Order by "order number"; break ties lexicographically.
+ return std::make_pair(Order(a), StringPiece(a.type())) <
+ std::make_pair(Order(b), StringPiece(b.type()));
+}
+
+std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
+ std::vector<DeviceType> result;
+ std::set<string> seen;
+ for (Device* d : devices_) {
+ auto t = d->device_type();
+ if (seen.insert(t).second) {
+ result.emplace_back(DeviceType(t));
+ }
+ }
+ std::sort(result.begin(), result.end(), ByPriority);
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h
new file mode 100644
index 0000000000..130d965891
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_set.h
@@ -0,0 +1,64 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// DeviceSet is a container class for managing the various types of
+// devices used by a model.
+class DeviceSet {
+ public:
+ DeviceSet();
+ ~DeviceSet();
+
+ // Does not take ownership of 'device'.
+ void AddDevice(Device* device);
+
+ // Set the device designated as the "client". This device
+ // must also be registered via AddDevice().
+ void set_client_device(Device* device) { client_device_ = device; }
+
+ // Returns a pointer to the device designated as the "client".
+ Device* client_device() const { return client_device_; }
+
+ // Return the list of devices in this set.
+ const std::vector<Device*>& devices() const { return devices_; }
+
+ // Given a DeviceNameUtils::ParsedName (which may have some
+ // wildcards for different components), fills "*devices" with all
+ // devices in "*this" that match "spec".
+ void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
+ std::vector<Device*>* devices) const;
+
+ // Finds the device with the given "fullname". Returns nullptr if
+ // not found.
+ Device* FindDeviceByName(const string& fullname) const;
+
+ // Return the list of unique device types in this set, ordered
+ // with more preferable devices earlier.
+ std::vector<DeviceType> PrioritizedDeviceTypeList() const;
+
+ private:
+ // Not owned.
+ std::vector<Device*> devices_;
+
+ // Fullname -> device* for device in devices_.
+ std::unordered_map<string, Device*> device_by_name_;
+
+ // client_device_ points to an element of devices_ that we consider
+ // to be the client device (in this local process).
+ Device* client_device_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc
new file mode 100644
index 0000000000..1b80a5b697
--- /dev/null
+++ b/tensorflow/core/common_runtime/device_set_test.cc
@@ -0,0 +1,65 @@
+#include "tensorflow/core/common_runtime/device_set.h"
+
+#include "tensorflow/core/public/status.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+// Return a fake device with the specified type and name.
+static Device* Dev(const char* type, const char* name) {
+ class FakeDevice : public Device {
+ public:
+ explicit FakeDevice(const DeviceAttributes& attr)
+ : Device(nullptr, attr, nullptr) {}
+ Status Sync() override { return Status::OK(); }
+ Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+ };
+ DeviceAttributes attr;
+ attr.set_name(name);
+ attr.set_device_type(type);
+ return new FakeDevice(attr);
+}
+
+class DeviceSetTest : public testing::Test {
+ public:
+ void AddDevice(const char* type, const char* name) {
+ Device* d = Dev(type, name);
+ owned_.emplace_back(d);
+ devices_.AddDevice(d);
+ }
+
+ std::vector<DeviceType> types() const {
+ return devices_.PrioritizedDeviceTypeList();
+ }
+
+ private:
+ DeviceSet devices_;
+ std::vector<std::unique_ptr<Device>> owned_;
+};
+
+TEST_F(DeviceSetTest, PrioritizedDeviceTypeList) {
+ EXPECT_EQ(std::vector<DeviceType>{}, types());
+
+ AddDevice("CPU", "/job:a/replica:0/task:0/cpu:0");
+ EXPECT_EQ(std::vector<DeviceType>{DeviceType(DEVICE_CPU)}, types());
+
+ AddDevice("CPU", "/job:a/replica:0/task:0/cpu:1");
+ EXPECT_EQ(std::vector<DeviceType>{DeviceType(DEVICE_CPU)}, types());
+
+ AddDevice("GPU", "/job:a/replica:0/task:0/gpu:0");
+ EXPECT_EQ(
+ (std::vector<DeviceType>{DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
+ types());
+
+ AddDevice("T1", "/job:a/replica:0/task:0/device:T1:0");
+ AddDevice("T1", "/job:a/replica:0/task:0/device:T1:1");
+ AddDevice("T2", "/job:a/replica:0/task:0/device:T2:0");
+ EXPECT_EQ(
+ (std::vector<DeviceType>{DeviceType("T1"), DeviceType("T2"),
+ DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}),
+ types());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eigen_thread_pool.h b/tensorflow/core/common_runtime/eigen_thread_pool.h
new file mode 100644
index 0000000000..2554f3521b
--- /dev/null
+++ b/tensorflow/core/common_runtime/eigen_thread_pool.h
@@ -0,0 +1,22 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
+#define TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override { pool_->Schedule(fn); }
+
+ private:
+ thread::ThreadPool* pool_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
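
The wrapper above exists so that a tensorflow thread::ThreadPool can be handed to Eigen's tensor expression evaluator. A sketch of the intended use, assuming the (interface, thread-count) constructor of Eigen::ThreadPoolDevice from the unsupported Tensor module included above:

#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/env.h"

void EigenDeviceSketch() {
  const int num_threads = 4;
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "example_pool", num_threads);
  tensorflow::EigenThreadPoolWrapper wrapper(&pool);
  Eigen::ThreadPoolDevice device(&wrapper, num_threads);
  // 'device' can now drive Eigen tensor expressions, e.g.
  //   out.device(device) = in1 + in2;
}
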
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
new file mode 100644
index 0000000000..7f2473f93b
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -0,0 +1,2118 @@
+#include "tensorflow/core/common_runtime/executor.h"
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <deque>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/edgeset.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+namespace tensorflow {
+
+namespace {
+
+// 1-D, 0 element tensor.
+static const Tensor* const kEmptyTensor = new Tensor;
+
+bool IsInitializationOp(const Node* node) {
+ return node->op_def().allows_uninitialized_input();
+}
+
+// Sets the timeline_label field of *node_stats, using data from *node.
+// Returns true iff the node is a transfer node.
+// TODO(tucker): merge with the DetailText function in session.cc
+// in a common location.
+bool SetTimelineLabel(const Node* node, NodeExecStats* node_stats) {
+ bool is_transfer_node = false;
+ string memory;
+ for (auto& all : node_stats->memory()) {
+ int64 tot = all.total_bytes();
+ if (tot >= 0.1 * 1048576.0) {
+ int64 peak = all.peak_bytes();
+ if (peak > 0) {
+ memory =
+ strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
+ peak / 1048576.0));
+ } else {
+ memory = strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB] ", tot / 1048576.0));
+ }
+ }
+ }
+ const NodeDef& def = node->def();
+ string text = "";
+ if (IsSend(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(def, "recv_device", &recv_device));
+ text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
+ tensor_name, " @", recv_device);
+ is_transfer_node = true;
+ } else if (IsRecv(node)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(def, "tensor_name", &tensor_name));
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(def, "send_device", &send_device));
+ text = strings::StrCat(memory, def.name(), " = ", def.op(), "(",
+ tensor_name, " @", send_device);
+ is_transfer_node = true;
+ } else {
+ text = strings::StrCat(
+ memory, def.name(), " = ", def.op(), "(",
+ str_util::Join(
+ std::vector<StringPiece>(def.input().begin(), def.input().end()),
+ ", "),
+ ")");
+ }
+ node_stats->set_timeline_label(text);
+ return is_transfer_node;
+}
+
+// Helper routines for collecting step stats.
+namespace nodestats {
+inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
+
+void SetScheduled(NodeExecStats* nt, int64 t) { nt->set_scheduled_micros(t); }
+
+void SetAllStart(NodeExecStats* nt) { nt->set_all_start_micros(NowInUsec()); }
+
+void SetOpStart(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetOpEnd(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetAllEnd(NodeExecStats* nt) {
+ DCHECK_NE(nt->all_start_micros(), 0);
+ nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
+}
+
+void SetOutput(NodeExecStats* nt, int slot, AllocationType allocation_type,
+ const Tensor* v) {
+ DCHECK(v);
+ NodeOutput* no = nt->add_output();
+ no->set_slot(slot);
+ no->set_allocation_type(allocation_type);
+ v->FillDescription(no->mutable_tensor_description());
+}
+
+void SetMemory(NodeExecStats* nt, OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AllocatorMemoryUsed* memory = nt->add_memory();
+ // retrieving the sizes from the wrapped allocator removes the
+ // executor's reference to it, so allocator_pair.second must not
+ // be dereferenced again after this statement
+ auto sizes = allocator_pair.second->GetSizesAndUnRef();
+ memory->set_allocator_name(allocator_pair.first->Name());
+ int tb = sizes.first;
+ memory->set_total_bytes(tb);
+ if (allocator_pair.first->TracksAllocationSizes()) {
+ memory->set_peak_bytes(sizes.second);
+ }
+ }
+}
+} // namespace nodestats
+
+struct NodeItem {
+ // A graph node.
+ const Node* node = nullptr;
+
+ // The kernel for this node.
+ OpKernel* kernel = nullptr;
+
+  // The index into the per-iteration input-tensor vector at which this
+  // node's first positional input is stored.
+ int input_start = 0;
+};
+
+// Map from std::pair<node_id, output_index> to attributes.
+struct pairhash {
+ public:
+ template <typename T, typename U>
+ std::size_t operator()(const std::pair<T, U>& x) const {
+ return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
+ }
+};
+typedef std::unordered_map<std::pair<int, int>, AllocatorAttributes, pairhash>
+ DevAttrMap;
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class ExecutorImpl : public Executor {
+ public:
+ ExecutorImpl(const LocalExecutorParams& p, const Graph* g)
+ : params_(p), graph_(g) {
+ CHECK(p.create_kernel != nullptr);
+ CHECK(p.delete_kernel != nullptr);
+ }
+
+ ~ExecutorImpl() override {
+ for (NodeItem& item : nodes_) {
+ params_.delete_kernel(item.kernel);
+ }
+ delete graph_;
+ }
+
+ Status Initialize();
+
+  // Infers the memory allocation attributes of node n's output,
+  // based on the node dst that consumes it. Note that dst might not be directly
+ // connected to n by a single edge, but might be a downstream
+ // consumer of n's output by reference. *attr is updated with any
+ // necessary attributes.
+ Status InferAllocAttr(const Node* n, const Node* dst,
+ const DeviceNameUtils::ParsedName& local_dev_name,
+ AllocatorAttributes* attr);
+
+ // Process all Nodes in the current graph, attempting to infer the
+ // memory allocation attributes to be used wherever they may allocate
+ // a tensor buffer.
+ Status SetAllocAttrs();
+
+ void RunAsync(const Args& args, DoneCallback done) override;
+
+ private:
+ friend class ExecutorState;
+ friend class SimpleExecutorState;
+
+ // Owned.
+ LocalExecutorParams params_;
+ const Graph* graph_;
+ std::vector<NodeItem> nodes_; // nodes_.size == graph_.num_node_ids().
+ int total_tensors_ = 0; // total_tensors_ = sum(nodes_[*].num_inputs())
+
+ // The number of inputs for each frame in this graph. This is static
+ // information of the graph.
+ std::unordered_map<string, int> frame_input_count_;
+
+ DevAttrMap alloc_attr_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
+};
+
+Status ExecutorImpl::Initialize() {
+ const int num_nodes = graph_->num_node_ids();
+ nodes_.resize(num_nodes);
+
+ Status s;
+ total_tensors_ = 0;
+
+  // Preprocess every node in the graph to create an instance of the op
+  // kernel for each node.
+ for (const Node* n : graph_->nodes()) {
+ const int id = n->id();
+ NodeItem* item = &nodes_[id];
+ item->node = n;
+ item->input_start = total_tensors_;
+ total_tensors_ += n->num_inputs();
+ s = params_.create_kernel(n->def(), &item->kernel);
+ if (!s.ok()) {
+ s = AttachDef(s, n->def());
+ LOG(ERROR) << "Executor failed to create kernel. " << s;
+ break;
+ }
+ CHECK(item->kernel);
+
+ // Initialize static information about the frames in the graph.
+ if (IsEnter(n)) {
+ string frame_name;
+ s = GetNodeAttr(n->def(), "frame_name", &frame_name);
+ if (!s.ok()) return s;
+ ++frame_input_count_[frame_name];
+ }
+ }
+ if (params_.has_control_flow) {
+ VLOG(2) << "Graph has control flow.";
+ }
+ if (!s.ok()) return s;
+ return SetAllocAttrs();
+}
+
+Status ExecutorImpl::SetAllocAttrs() {
+ Status s;
+ Device* device = params_.device;
+ DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
+
+ for (const Node* n : graph_->nodes()) {
+ // Examine the out edges of each node looking for special use
+ // cases that may affect memory allocation attributes.
+ for (auto e : n->out_edges()) {
+ AllocatorAttributes attr;
+ s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
+ if (!s.ok()) return s;
+ if (attr.value != 0) {
+ VLOG(2) << "node " << n->name() << " gets attr " << attr.value
+ << " for output " << e->src_output();
+ alloc_attr_[std::make_pair(n->id(), e->src_output())].Merge(attr);
+ } else {
+ VLOG(2) << "default output attr for node " << n->name() << " output "
+ << e->src_output();
+ }
+ }
+ }
+ return s;
+}
+
+Status ExecutorImpl::InferAllocAttr(
+ const Node* n, const Node* dst,
+ const DeviceNameUtils::ParsedName& local_dev_name,
+ AllocatorAttributes* attr) {
+ Status s;
+ if (IsSend(dst)) {
+ string dst_name;
+ s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
+ if (!s.ok()) return s;
+ DeviceNameUtils::ParsedName parsed_dst_name;
+ if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
+ s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
+ n->name());
+ return s;
+ }
+ if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
+ // Value is going to be the source of an RPC.
+ attr->set_nic_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the source of an RPC out";
+ } else if (local_dev_name.type == "CPU" && parsed_dst_name.type == "GPU") {
+ // Value is going to be the source of a local DMA from CPU to GPU.
+ attr->set_gpu_compatible(true);
+ VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
+ } else {
+ VLOG(2) << "default alloc case local type " << local_dev_name.type
+ << " remote type " << parsed_dst_name.type;
+ }
+ } else if (dst->type_string() == "ToFloat") {
+ for (auto e : dst->out_edges()) {
+ s = InferAllocAttr(n, e->dst(), local_dev_name, attr);
+ if (!s.ok()) return s;
+ }
+ }
+ return s;
+}
+
+// The state associated with one invocation of ExecutorImpl::Run.
+// ExecutorState dispatches nodes when they become ready, and keeps
+// track of how many predecessors of a node have not yet completed (pending_).
+class ExecutorState {
+ public:
+ ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
+ ~ExecutorState();
+
+ void RunAsync(Executor::DoneCallback done);
+
+ private:
+ typedef ExecutorState ME;
+
+ // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ // TODO(yuanbyu): A better way to do "has_value"?
+ struct Entry {
+ Tensor val = *kEmptyTensor; // A tensor value.
+ Tensor* ref = nullptr; // A tensor reference.
+ mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
+ bool has_value = false; // Whether the value exists
+
+ // Every entry carries an optional DeviceContext containing
+ // Device-specific information about how the Tensor was produced.
+ DeviceContext* device_context = nullptr;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+ };
+
+ // Contains a map from node id to the DeviceContext object that was
+ // assigned by the device at the beginning of a step.
+ DeviceContextMap device_context_map_;
+
+ struct IterationState {
+ // The state of an iteration.
+
+ // The pending count for each graph node. One copy per iteration.
+ // Iteration i can be garbage collected when it is done.
+ // TODO(yuanbyu): This vector currently has size of the number of nodes
+ // in this partition. This is not efficient if the subgraph for the frame
+ // is only a small subset of the partition. We should make the vector
+ // size to be only the size of the frame subgraph.
+ std::vector<int>* pending_count;
+
+ // The dead input count for each graph node. One copy per iteration.
+ std::vector<int>* dead_count;
+
+ // One copy per iteration. For iteration k, i-th node's j-th input is in
+ // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
+ // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ //
+ // NOTE: No need to protect input_tensors[i] by any locks because it
+ // is resized once. Each element of tensors_ is written once by the
+ // source node of an edge and is cleared by the destination of the same
+ // edge. The latter node is never run concurrently with the former node.
+ std::vector<Entry>* input_tensors;
+
+ // The number of outstanding ops for each iteration.
+ int outstanding_ops;
+
+ // The number of outstanding frames for each iteration.
+ int outstanding_frame_count;
+
+ ~IterationState() {
+ delete pending_count;
+ delete dead_count;
+ delete input_tensors;
+ }
+ };
+
+ struct FrameState {
+ // A new frame is created for each loop. Execution starts at iteration 0.
+ // When a value at iteration 0 passes through a NextIteration node,
+ // iteration 1 is created and starts running. Note that iteration 0 may
+ // still be running so multiple iterations may run in parallel. The
+ // frame maintains the state of iterations in several data structures
+ // such as pending_count and input_tensors. When iteration 0 completes,
+ // we garbage collect the state of iteration 0.
+ //
+ // A frame instance is considered "done" and can be garbage collected
+ // if all its inputs have entered and all its iterations are "done".
+ //
+ // A frame manages the live iterations of an iterative computation.
+ // Iteration i is considered "done" when there are no outstanding ops,
+ // frames at iteration i are done, all recvs for this iteration are
+ // completed, and iteration i-1 is done. For iteration 0, we instead
+ // wait for there to be no more pending inputs of the frame.
+ //
+ // Frames and iterations are garbage collected once they are done.
+ // The state we need to keep around is highly dependent on the
+ // parallelism enabled by the scheduler. We may want to have the
+ // scheduler dynamically control the outstanding number of live
+ // parallel frames and iterations. To reduce the state space, the
+ // scheduler might want to schedule ops in inner frames first and
+ // lower iterations first.
+ //
+ // This frame state is mostly initialized lazily on demand so we
+ // don't introduce unnecessary overhead.
+
+ // The name of this frame, which is the concatenation of its parent
+ // frame name, the iteration of the parent frame when this frame was
+ // created, and the value of the attr 'frame_name'.
+ string frame_name;
+
+ // The unique id for this frame. Generated by fingerprinting
+ // frame_name.
+ uint64 frame_id;
+
+ // The iteration id of its parent frame when this frame is created.
+ // -1 if there is no parent frame. The frame_name/parent_iter pair
+ // uniquely identifies this FrameState.
+ int64 parent_iter = -1;
+
+ // The FrameState of its parent frame.
+ FrameState* parent_frame = nullptr;
+
+ // The highest iteration number we have reached so far in this frame.
+ int64 iteration_count = 0;
+
+ // The number of inputs this frame is still waiting for.
+ int num_pending_inputs = 0;
+
+ // The number of outstanding iterations.
+ int num_outstanding_iterations = 0;
+
+ // The maximum allowed number of parallel iterations.
+ int max_parallel_iterations = 1;
+
+ // The iteration states of this frame.
+ std::vector<IterationState*> iterations;
+
+ // The NextIteration nodes to enter a new iteration. If the number of
+ // outstanding iterations reaches the limit, we will defer the start of
+ // the next iteration until the number of outstanding iterations falls
+ // below the limit.
+ std::vector<std::pair<const Node*, Entry>> next_iter_roots;
+
+ // The values of the loop invariants for this loop. They are added into
+ // this list as they "enter" the frame. When a loop invariant enters,
+ // we make it available to all active iterations. When the frame starts
+ // a new iteration, we make all the current loop invariants available
+ // to the new iteration.
+ std::vector<std::pair<const Node*, Entry>> inv_values;
+
+ // The list of dead exit nodes for the current highest iteration. We
+ // will only "execute" the dead exits of the final iteration.
+ std::vector<const Node*> dead_exits;
+
+ IterationState* GetIteration(int64 iter) {
+ int index = iter % iterations.size();
+ return iterations[index];
+ }
+
+ void SetIteration(int64 iter, IterationState* state) {
+ int index = iter % iterations.size();
+ iterations[index] = state;
+ }
+
+ ~FrameState() {
+ for (size_t i = 0; i < iterations.size(); ++i) {
+ delete iterations[i];
+ iterations[i] = nullptr;
+ }
+ }
+ };
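+
+ // Illustrative example of the circular buffer behind GetIteration and
+ // SetIteration above: FindOrCreateChildFrame sizes 'iterations' to
+ // max_parallel_iterations + 1, so with max_parallel_iterations == 2,
+ // GetIteration(0) -> iterations[0 % 3] and
+ // GetIteration(3) -> iterations[3 % 3] == iterations[0],
+ // i.e. slot 0 can be reused for iteration 3 after iteration 0 has been
+ // garbage collected by CleanupFramesIterations.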
+
+ // A tagged node: <frame*, iter, node*>.
+ struct TaggedNode {
+ const Node* node = nullptr;
+ FrameState* input_frame = nullptr;
+ int64 input_iter = -1;
+ bool is_dead = false;
+
+ TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
+ bool dead) {
+ node = t_node;
+ input_frame = in_frame;
+ input_iter = in_iter;
+ is_dead = dead;
+ }
+ };
+
+ typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
+ typedef gtl::InlinedVector<Entry, 4> EntryVector;
+
+ // Not owned.
+ Rendezvous* rendezvous_;
+ StepStatsCollector* stats_collector_;
+ // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper instead of a
+ // pointer? (avoids having to delete).
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
+ FunctionCallFrame* call_frame_;
+ const ExecutorImpl* impl_;
+ CancellationManager* cancellation_manager_;
+ Executor::Args::Runner runner_;
+
+ // Owned.
+
+ // Step-local resource manager.
+ ResourceMgr step_resource_manager_;
+
+ // The root frame in which the execution of this step is started.
+ FrameState* root_frame_;
+
+ // Invoked when the execution finishes.
+ Executor::DoneCallback done_cb_;
+
+ std::atomic_int_fast32_t num_outstanding_ops_;
+
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+
+ // Mapping from frame name to outstanding frames. A new frame is created
+ // at some iteration of an active frame. So the unique key for the new
+ // child frame is composed of the name of the parent frame, the iteration
+ // number at which the parent frame is creating the new frame, and the
+ // value of the attr 'frame_name' from the NodeDef.
+ std::unordered_map<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
+
+ // The unique name of a frame.
+ inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) {
+ return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
+ }
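+
+ // For example (the iteration number and attr value are illustrative), an
+ // Enter node with frame_name attr "while" seen at iteration 7 of the
+ // root frame yields the child frame name "_root;7;while".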
+
+ // Initialize the pending count for a graph.
+ static void InitializePending(const Graph* graph, std::vector<int>* pending);
+
+ // Find an existing or create a new child frame in the frame 'frame' at
+ // iteration 'iter'.
+ void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
+ FrameState** child) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Increments the iteration id. If this is a new iteration, initialize it.
+ void IncrementIteration(FrameState* frame, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns true if the computation in the frame is completed.
+ bool IsFrameDone(FrameState* frame) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns true if the iteration of the frame is completed.
+ bool IsIterationDone(FrameState* frame, int64 iter)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Get the output frame/iter of a node. Create new frame/iteration if
+ // needed. If there are dead roots for the new iteration, we need to
+ // "execute" them so ad them to the ready queue. Returns true if
+ // we need to check for the completion of output frame/iter.
+ bool SetOutputFrameIter(const TaggedNode& tagged_node,
+ const EntryVector& outputs, FrameState** frame,
+ int64* iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Clean up frames and iterations.
+ void CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate all the deferred NextIteration nodes in a new iteration.
+ void ActivateNexts(FrameState* frame, int64 iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate all the current loop invariants in a new iteration.
+ void ActivateLoopInvs(FrameState* frame, int64 iter, TaggedNodeSeq* ready)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Add a new loop invariant and make it available to all active iterations.
+ void AddLoopInv(FrameState* frame, const Node* node, const Entry& value,
+ TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Activate the successors of a node.
+ void ActivateNode(const Node* node, const bool is_dead, FrameState* frame,
+ int64 iter, const EntryVector& outputs,
+ TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Process a ready node in current thread.
+ void Process(TaggedNode node, int64 scheduled_usec);
+
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead);
+
+ // After item->kernel computation is done, processes its outputs.
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs, NodeExecStats* stats);
+
+ // After processing the outputs, propagates the outputs to their dsts.
+ void PropagateOutputs(const TaggedNode& tagged_node,
+ const EntryVector& outputs, TaggedNodeSeq* ready);
+
+ // "node" just finishes. Takes ownership of "stats". Returns true if
+ // execution has completed.
+ bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
+ NodeExecStats* stats, std::deque<TaggedNode>* inline_ready);
+
+ // Call Process() on all nodes in 'inline_ready'.
+ void ProcessInline(const std::deque<TaggedNode>& inline_ready);
+
+ // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ void ScheduleReady(const TaggedNodeSeq& ready,
+ std::deque<TaggedNode>* inline_ready);
+
+ // One thread of control finishes.
+ void Finish();
+};
+
+ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
+ : rendezvous_(args.rendezvous),
+ stats_collector_(args.stats_collector),
+ slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
+ call_frame_(args.call_frame),
+ impl_(impl),
+ cancellation_manager_(args.cancellation_manager),
+ runner_(args.runner),
+ num_outstanding_ops_(0) {
+ // We start the entire execution in iteration 0 of the root frame
+ // so let us create the root frame and the state for iteration 0.
+ // Initialize the frame.
+ root_frame_ = new FrameState;
+ root_frame_->frame_name = "_root"; // assumed to be unique
+ root_frame_->frame_id = 0; // must be 0
+ root_frame_->num_pending_inputs = 0;
+ root_frame_->num_outstanding_iterations = 1;
+ root_frame_->max_parallel_iterations = 1; // enough for root frame
+ root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
+
+ VLOG(2) << "Create frame: " << root_frame_->frame_name;
+
+ // Initialize the iteration.
+ IterationState* iter_state = new IterationState;
+ root_frame_->iterations[0] = iter_state;
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ iter_state->dead_count = new std::vector<int>(impl->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ // Initialize the executor state.
+ outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
+}
+
+ExecutorState::~ExecutorState() {
+ for (auto name_frame : outstanding_frames_) {
+ delete name_frame.second;
+ }
+
+ for (auto it : device_context_map_) {
+ it.second->Unref();
+ }
+
+ delete slice_reader_cache_;
+}
+
+void ExecutorState::InitializePending(const Graph* graph,
+ std::vector<int>* pending) {
+ pending->resize(graph->num_node_ids());
+ for (const Node* n : graph->nodes()) {
+ const int id = n->id();
+ const int num_in_edges = n->in_edges().size();
+ if (IsMerge(n)) {
+ // A Merge node waits for all of its control inputs, so we initialize
+ // the pending count to the number of control edges.
+ int32 num_control_edges = 0;
+ for (const Edge* edge : n->in_edges()) {
+ if (edge->IsControlEdge()) {
+ num_control_edges++;
+ }
+ }
+ // Use bit 0 to indicate if there is a ready live data input.
+ (*pending)[id] = num_control_edges << 1;
+ } else {
+ (*pending)[id] = num_in_edges;
+ }
+ }
+}
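+
+// For example (the edge counts are illustrative), a Merge node with two
+// control edges starts with pending = 2 << 1 = 4; ActivateNode later
+// subtracts 2 for each incoming control edge and ORs in bit 0 when a live
+// data input arrives.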
+
+void ExecutorState::RunAsync(Executor::DoneCallback done) {
+ const Graph* graph = impl_->graph_;
+ TaggedNodeSeq ready;
+
+ {
+ // Initialize the executor state. We grab the mutex here just to
+ // keep the thread safety analysis happy.
+ mutex_lock l(mu_);
+ std::vector<int>* pending = root_frame_->iterations[0]->pending_count;
+ InitializePending(graph, pending);
+ }
+
+ // Ask the device to fill in the device context map.
+ Device* device = impl_->params_.device;
+ device->FillContextMap(graph, &device_context_map_);
+
+ // Initialize the ready queue.
+ for (const Node* n : graph->nodes()) {
+ const int num_in_edges = n->in_edges().size();
+ if (num_in_edges == 0) {
+ ready.push_back(TaggedNode{n, root_frame_, 0, false});
+ }
+ }
+ if (ready.empty()) {
+ done(Status::OK());
+ } else {
+ num_outstanding_ops_ = ready.size();
+ root_frame_->iterations[0]->outstanding_ops = ready.size();
+ done_cb_ = done;
+ // Schedule to run all the ready ops in thread pool.
+ ScheduleReady(ready, nullptr);
+ }
+}
+
+namespace {
+
+// This function is provided for use by OpKernelContext when allocating
+// the index'th output of node. It provides access to the
+// AllocatorAttributes computed during initialization to determine in
+// which memory region the tensor should be allocated.
+AllocatorAttributes OutputAttributes(const DevAttrMap* attr_map,
+ const Node* node,
+ const OpKernel* op_kernel, int index) {
+ DCHECK_GE(index, 0);
+
+ AllocatorAttributes attr;
+ int nid = node->id();
+ const auto& iter = attr_map->find(std::make_pair(nid, index));
+ if (iter != attr_map->end()) {
+ attr = iter->second;
+ VLOG(2) << "nondefault attr " << attr.value << " for node " << node->name()
+ << " output " << index;
+ } else {
+ VLOG(2) << "default attr for node " << node->name() << " output " << index;
+ }
+
+ DCHECK_LT(index, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[index] == HOST_MEMORY;
+ attr.set_on_host(on_host);
+ return attr;
+}
+
+// Helper to make a copy of 'p', including copies of its input vector,
+// input device context vector, and input allocator attributes.
+//
+// NOTE: We need to make a copy of p.inputs for asynchronous kernels
+// because OpKernelContext methods like input_type(i) need the params to
+// point to a valid input type vector. It's not an issue for sync
+// kernels because the type vector is kept on the stack.
+OpKernelContext::Params* CopyParams(const OpKernelContext::Params& p) {
+ OpKernelContext::Params* ret = new OpKernelContext::Params;
+ *ret = p;
+ ret->inputs = new TensorValueVec(*p.inputs);
+ ret->input_device_contexts = new DeviceContextVec(*p.input_device_contexts);
+ ret->input_alloc_attrs = new AllocatorAttributeVec(*p.input_alloc_attrs);
+ return ret;
+}
+
+// Helper to delete 'p' and the copies made by CopyParams.
+void DeleteParams(OpKernelContext::Params* p) {
+ delete p->inputs;
+ delete p->input_device_contexts;
+ delete p->input_alloc_attrs;
+ delete p;
+}
+
+} // namespace
+
+void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ TaggedNodeSeq ready;
+ std::deque<TaggedNode> inline_ready;
+
+ // Parameters passed to OpKernel::Compute.
+ TensorValueVec inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ OpKernelContext::Params params;
+ Device* device = impl_->params_.device;
+ params.device = device;
+ // track allocations if and only if we are collecting statistics
+ params.track_allocations = (stats_collector_ != nullptr);
+ params.rendezvous = rendezvous_;
+ params.cancellation_manager = cancellation_manager_;
+ params.call_frame = call_frame_;
+ params.function_library = impl_->params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_resource_manager = &step_resource_manager_;
+ params.slice_reader_cache = slice_reader_cache_;
+ params.inputs = &inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Status s;
+ NodeExecStats* stats = nullptr;
+ EntryVector outputs;
+ bool completed = false;
+ inline_ready.push_back(tagged_node);
+ while (!inline_ready.empty()) {
+ tagged_node = inline_ready.front();
+ inline_ready.pop_front();
+ const Node* node = tagged_node.node;
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+ const int id = node->id();
+ const NodeItem& item = nodes[id];
+
+ // Set the device_context for this node id, if it exists.
+ auto dc_it = device_context_map_.find(id);
+ if (dc_it != device_context_map_.end()) {
+ params.op_device_context = dc_it->second;
+ }
+
+ if (stats_collector_) {
+ stats = new NodeExecStats;
+ stats->set_node_name(node->name());
+ nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetAllStart(stats);
+ }
+
+ VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
+
+ std::vector<Entry>* input_tensors;
+ {
+ // Need the lock because the iterations vector could be resized by
+ // another thread.
+ mutex_lock l(mu_);
+ input_tensors = input_frame->GetIteration(input_iter)->input_tensors;
+ }
+ Entry* first_input = input_tensors->data() + item.input_start;
+ outputs.clear();
+ outputs.resize(node->num_outputs());
+
+ // Only execute this node if it is not dead or it is a send/recv
+ // transfer node. For transfer nodes, we need to propagate the "dead"
+ // bit even when the node is dead.
+ AsyncOpKernel* async = nullptr;
+ if (!tagged_node.is_dead || IsTransferNode(node)) {
+ // Prepares inputs.
+ bool is_input_dead = false;
+ s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
+ &input_alloc_attrs, &is_input_dead);
+ if (!s.ok()) {
+ // Continue to process the nodes in 'inline_ready'.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ continue;
+ }
+
+ // Set up compute params.
+ OpKernel* op_kernel = item.kernel;
+ params.op_kernel = op_kernel;
+ params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
+ params.is_input_dead = is_input_dead;
+ params.output_alloc_attr = [this, node, op_kernel](int index) {
+ return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index);
+ };
+
+ async = op_kernel->AsAsync();
+ if (async) {
+ // Asynchronous computes.
+ auto pcopy = CopyParams(params);
+ auto ctx = new OpKernelContext(*pcopy);
+ auto done = [this, tagged_node, item, first_input, ctx, stats,
+ pcopy]() {
+ VLOG(2) << this << " Async kernel done: "
+ << SummarizeNodeDef(item.node->def());
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+ EntryVector outputs;
+ Status s = ProcessOutputs(item, ctx, &outputs, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, ctx);
+ // Clears inputs.
+ int num_inputs = tagged_node.node->num_inputs();
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->val = *kEmptyTensor;
+ }
+ TaggedNodeSeq ready;
+ if (s.ok()) {
+ PropagateOutputs(tagged_node, outputs, &ready);
+ }
+ // Schedule to run all the ready ops in thread pool.
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr);
+ delete ctx;
+ DeleteParams(pcopy);
+ if (completed) Finish();
+ };
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->ComputeAsync(async, ctx, done);
+ } else {
+ // Synchronous computes.
+ OpKernelContext ctx(params);
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+
+ // Processes outputs.
+ s = ProcessOutputs(item, &ctx, &outputs, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, &ctx);
+ }
+ }
+
+ if (!async) {
+ // Clears inputs.
+ int num_inputs = node->num_inputs();
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->val = *kEmptyTensor;
+ }
+ // Propagates outputs.
+ if (s.ok()) {
+ PropagateOutputs(tagged_node, outputs, &ready);
+ }
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ // Postprocess.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ }
+ } // while !inline_ready.empty()
+
+ // This thread of computation is done if completed = true.
+ if (completed) Finish();
+}
+
+Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
+ TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts,
+ AllocatorAttributeVec* input_alloc_attrs,
+ bool* is_input_dead) {
+ const Node* node = item.node;
+
+ inputs->clear();
+ inputs->resize(node->num_inputs());
+ input_device_contexts->clear();
+ input_device_contexts->resize(node->num_inputs());
+ input_alloc_attrs->clear();
+ input_alloc_attrs->resize(node->num_inputs());
+
+ *is_input_dead = false;
+
+ bool is_merge = IsMerge(node);
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ const bool expect_ref = IsRefType(node->input_type(i));
+ Entry* entry = first_input + i;
+ (*input_device_contexts)[i] = entry->device_context;
+ (*input_alloc_attrs)[i] = entry->alloc_attr;
+
+ // i-th input.
+ TensorValue* inp = &(*inputs)[i];
+
+ // Only merge and transfer nodes can have no-value inputs.
+ if (!entry->has_value) {
+ if (!is_merge) {
+ DCHECK(IsTransferNode(node));
+ inp->tensor = &entry->val;
+ *is_input_dead = true;
+ }
+ continue;
+ }
+ if (entry->ref == nullptr) {
+ if (expect_ref) {
+ return AttachDef(
+ errors::InvalidArgument(i, "-th input expects a ref type"),
+ item.kernel->def());
+ }
+ inp->tensor = &entry->val;
+ } else {
+ if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
+ return AttachDef(
+ errors::FailedPrecondition("Attempting to use uninitialized value ",
+ item.kernel->def().input(i)),
+ item.kernel->def());
+ }
+ if (expect_ref) {
+ inp->mutex_if_ref = entry->ref_mu;
+ inp->tensor = entry->ref;
+ } else {
+ // Automatically deref the tensor ref when the op expects a
+ // tensor but is given a ref to a tensor. Need to deref it
+ // under the mutex.
+ {
+ mutex_lock l(*(entry->ref_mu));
+ entry->val = *entry->ref;
+ }
+ inp->tensor = &entry->val;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ EntryVector* outputs,
+ NodeExecStats* stats) {
+ const Node* node = item.node;
+ outputs->clear();
+ outputs->resize(node->num_outputs());
+
+ Status s = ctx->status();
+ if (!s.ok()) {
+ s = AttachDef(s, item.kernel->def());
+ LOG(WARNING) << this << " Compute status: " << s;
+ return s;
+ }
+
+ // Get the device_context for this node id, if it exists.
+ DeviceContext* device_context = nullptr;
+ auto dc_it = device_context_map_.find(node->id());
+ if (dc_it != device_context_map_.end()) {
+ device_context = dc_it->second;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ TensorValue val = ctx->release_output(i);
+ // Only Switch and Recv nodes can generate new dead outputs
+ if (*ctx->is_output_dead() || val.tensor == nullptr) {
+ DCHECK(IsSwitch(node) || IsRecv(node));
+ } else {
+ Entry* out = &((*outputs)[i]);
+ out->has_value = true;
+
+ // Set the device context of the output entry.
+ out->device_context = device_context;
+
+ // Set the allocator attributes of the output entry.
+ out->alloc_attr = ctx->output_alloc_attr(i);
+
+ // Sanity check of output tensor types.
+ DataType dtype = val->dtype();
+ if (val.is_ref()) dtype = MakeRefType(dtype);
+ if (dtype == node->output_type(i)) {
+ if (val.is_ref()) {
+ out->ref = val.tensor;
+ out->ref_mu = val.mutex_if_ref;
+ } else {
+ out->val = *val.tensor;
+ }
+ if (stats_collector_ && val.tensor->IsInitialized()) {
+ nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
+ val.tensor);
+ }
+ } else {
+ s.Update(errors::Internal("Output ", i, " of type ",
+ DataTypeString(dtype),
+ " does not match declared output type ",
+ DataTypeString(node->output_type(i)),
+ " for node ", SummarizeNodeDef(node->def())));
+ }
+ }
+ if (!val.is_ref()) {
+ // If OpKernelContext returns outputs via pass-by-value, we
+ // don't need this trouble.
+ delete val.tensor;
+ }
+ }
+ return s;
+}
+
+void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
+ const EntryVector& outputs,
+ TaggedNodeSeq* ready) {
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+
+ // Propagates outputs along out edges, and puts newly ready nodes
+ // into the ready queue.
+ ready->clear();
+
+ {
+ FrameState* output_frame = input_frame;
+ int64 output_iter = input_iter;
+
+ mutex_lock l(mu_);
+ // Sets the output_frame and output_iter of node.
+ bool maybe_completed = SetOutputFrameIter(
+ tagged_node, outputs, &output_frame, &output_iter, ready);
+ if (output_frame != nullptr) {
+ // Continue to process the out nodes:
+ ActivateNode(tagged_node.node, tagged_node.is_dead, output_frame,
+ output_iter, outputs, ready);
+ }
+
+ // At this point, this node is completely done.
+ input_frame->GetIteration(input_iter)->outstanding_ops--;
+ CleanupFramesIterations(input_frame, input_iter, ready);
+
+ // The execution of a node such as Enter may cause the completion of
+ // output_frame:output_iter, so perform cleanup if output_frame:output_iter
+ // is indeed completed.
+ if (maybe_completed) {
+ CleanupFramesIterations(output_frame, output_iter, ready);
+ }
+ }
+}
+
+void ExecutorState::ActivateNode(const Node* node, const bool is_dead,
+ FrameState* output_frame, int64 output_iter,
+ const EntryVector& outputs,
+ TaggedNodeSeq* ready) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ IterationState* output_iter_state = output_frame->GetIteration(output_iter);
+ std::vector<int>* pending = output_iter_state->pending_count;
+ std::vector<int>* dead_count = output_iter_state->dead_count;
+ for (const Edge* e : node->out_edges()) {
+ const Node* dst_node = e->dst();
+ const int dst_id = dst_node->id();
+ const int src_slot = e->src_output();
+
+ bool dst_dead = false;
+ bool dst_ready = false;
+ bool dst_need_input = !e->IsControlEdge();
+ if (IsMerge(dst_node)) {
+ // A merge node is ready if a) all control edges are enabled and a
+ // live data input becomes available, or b) all control edges are
+ // enabled and all data inputs are dead.
+ if (e->IsControlEdge()) {
+ (*pending)[dst_id] -= 2;
+ int count = (*pending)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = (count == 1) || ((count == 0) && dst_dead);
+ } else {
+ if (outputs[src_slot].has_value) {
+ // This is a live data input.
+ int count = (*pending)[dst_id];
+ (*pending)[dst_id] |= 0x1;
+ dst_ready = (count == 0);
+ } else {
+ // This is a dead data input.
+ ++(*dead_count)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = ((*pending)[dst_id] == 0) && dst_dead;
+ }
+ // This input for dst is not needed if !dst_ready. We suppress the
+ // propagation to make the thread safety analysis happy.
+ dst_need_input = dst_ready;
+ }
+ } else {
+ // A non-merge node is ready if all its inputs are ready. We wait
+ // for all inputs to come in even if we know the node is dead. This
+ // ensures that all input tensors get cleaned up.
+ if (is_dead || (!e->IsControlEdge() && !outputs[src_slot].has_value)) {
+ ++(*dead_count)[dst_id];
+ }
+ dst_dead = (*dead_count)[dst_id] > 0;
+ dst_ready = (--(*pending)[dst_id] == 0);
+ }
+
+ if (dst_need_input) {
+ const NodeItem& dst_item = nodes[dst_id];
+ const int dst_slot = e->dst_input();
+ std::vector<Entry>* input_tensors = output_iter_state->input_tensors;
+ int dst_loc = dst_item.input_start + dst_slot;
+ (*input_tensors)[dst_loc] = outputs[src_slot];
+ }
+
+ // Add dst to the ready queue if it's ready
+ if (dst_ready) {
+ dst_dead = dst_dead && !IsControlTrigger(dst_node);
+ ready->push_back(
+ TaggedNode(dst_node, output_frame, output_iter, dst_dead));
+ output_iter_state->outstanding_ops++;
+ }
+ }
+}
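+
+// Example trace of the Merge case above (the counts are illustrative): with
+// two control edges, pending starts at 4. If a live data input arrives
+// first, count == 4 so the node is not yet ready and pending becomes
+// 4 | 1 == 5; the two control edges then take pending to 3 and finally to
+// 1, at which point count == 1 and the Merge is pushed onto 'ready'.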
+
+void ExecutorState::ActivateNexts(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate the deferred NextIteration nodes to the new iteration.
+ for (auto& node_entry : frame->next_iter_roots) {
+ const Node* node = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = !entry.has_value;
+ ActivateNode(node, is_dead, frame, iter, {entry}, ready);
+ }
+ frame->next_iter_roots.clear();
+}
+
+void ExecutorState::ActivateLoopInvs(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ // Propagate loop invariants to the new iteration.
+ for (auto& node_entry : frame->inv_values) {
+ const Node* node = node_entry.first;
+ const Entry& entry = node_entry.second;
+ const bool is_dead = !entry.has_value;
+ ActivateNode(node, is_dead, frame, iter, {entry}, ready);
+ }
+}
+
+void ExecutorState::AddLoopInv(FrameState* frame, const Node* node,
+ const Entry& entry, TaggedNodeSeq* ready) {
+ // Store this value.
+ frame->inv_values.push_back({node, entry});
+
+ // Make this value available to all iterations.
+ bool is_dead = !entry.has_value;
+ for (int i = 1; i <= frame->iteration_count; ++i) {
+ ActivateNode(node, is_dead, frame, i, {entry}, ready);
+ }
+}
+
+bool ExecutorState::NodeDone(const Status& s, const Node* node,
+ const TaggedNodeSeq& ready, NodeExecStats* stats,
+ std::deque<TaggedNode>* inline_ready) {
+ if (stats_collector_) {
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else {
+ delete stats;
+ }
+ }
+
+ Rendezvous* captured_rendezvous = nullptr; // Will be set on error.
+ if (!s.ok()) {
+ // Some error happened. This thread of computation is done.
+ mutex_lock l(mu_);
+ if (status_.ok()) {
+ captured_rendezvous = rendezvous_;
+ if (captured_rendezvous) captured_rendezvous->Ref();
+ status_ = s;
+ }
+ }
+ if (captured_rendezvous) {
+ // If we captured the rendezvous_ pointer, we are in an error condition.
+ // Use captured_rendezvous, in case "this" is deleted by another thread.
+ TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+ captured_rendezvous->StartAbort(s);
+ captured_rendezvous->Unref();
+ }
+
+ bool completed = false;
+ int ready_size = ready.size();
+ if (ready_size == 0 || !s.ok()) {
+ completed = (num_outstanding_ops_.fetch_sub(1) == 1);
+ } else if (ready_size > 1) {
+ num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
+ }
+
+ // Schedule the ready nodes in 'ready'.
+ if (s.ok()) {
+ ScheduleReady(ready, inline_ready);
+ }
+ return completed;
+}
+
+void ExecutorState::ProcessInline(const std::deque<TaggedNode>& inline_ready) {
+ if (inline_ready.empty()) return;
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ for (auto& tagged_node : inline_ready) {
+ Process(tagged_node, scheduled_usec);
+ }
+}
+
+void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
+ std::deque<TaggedNode>* inline_ready) {
+ if (ready.empty()) return;
+
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ if (inline_ready == nullptr) {
+ // Schedule to run all the ready ops in thread pool.
+ for (auto& tagged_node : ready) {
+ runner_(std::bind(&ME::Process, this, tagged_node, scheduled_usec));
+ }
+ return;
+ }
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ const TaggedNode* curr_expensive_node = nullptr;
+ for (auto& tagged_node : ready) {
+ const NodeItem& item = nodes[tagged_node.node->id()];
+ if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
+ // Inline this inexpensive node.
+ inline_ready->push_back(tagged_node);
+ } else {
+ if (curr_expensive_node) {
+ // Dispatch to another thread since there is plenty of work to
+ // do for this thread.
+ runner_(std::bind(&ME::Process, this, *curr_expensive_node,
+ scheduled_usec));
+ }
+ curr_expensive_node = &tagged_node;
+ }
+ }
+ if (curr_expensive_node) {
+ if (inline_ready->empty()) {
+ // Tail recursion optimization
+ inline_ready->push_back(*curr_expensive_node);
+ } else {
+ // There are inline nodes to run already. We dispatch this expensive
+ // node to another thread.
+ runner_(
+ std::bind(&ME::Process, this, *curr_expensive_node, scheduled_usec));
+ }
+ }
+}
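+
+// Example of the policy above (the node names are illustrative): if 'ready'
+// holds {CheapA, ExpensiveB, CheapC, ExpensiveD}, the two cheap nodes go
+// into 'inline_ready' for this thread, ExpensiveB is handed to runner_ when
+// ExpensiveD is encountered, and ExpensiveD is also handed to runner_
+// because 'inline_ready' is not empty. Had 'ready' contained a single
+// expensive node with nothing inlined, that node would be pushed onto
+// 'inline_ready' instead (the tail recursion optimization).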
+
+void ExecutorState::Finish() {
+ mu_.lock();
+ auto status = status_;
+ auto done_cb = done_cb_;
+ auto runner = runner_;
+ mu_.unlock();
+ delete this;
+ CHECK(done_cb != nullptr);
+ runner([done_cb, status]() { done_cb(status); });
+}
+
+bool ExecutorState::IsFrameDone(FrameState* frame) {
+ return (frame->num_pending_inputs == 0 &&
+ frame->num_outstanding_iterations == 0);
+}
+
+bool ExecutorState::IsIterationDone(FrameState* frame, int64 iter) {
+ IterationState* iter_state = frame->GetIteration(iter);
+ if (iter_state->outstanding_ops == 0 &&
+ iter_state->outstanding_frame_count == 0) {
+ if (iter == 0) {
+ // The enclosing frame has no pending input.
+ return frame->num_pending_inputs == 0;
+ } else {
+ // The preceding iteration is deleted (and therefore done).
+ return (frame->GetIteration(iter - 1) == nullptr);
+ }
+ }
+ return false;
+}
+
+void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
+ const Node* node,
+ FrameState** child) {
+ // Get the child frame name.
+ string enter_name;
+ Status s = GetNodeAttr(node->def(), "frame_name", &enter_name);
+ CHECK(s.ok()) << s;
+ const string child_name = MakeFrameName(frame, iter, enter_name);
+
+ auto it = outstanding_frames_.find(child_name);
+ if (it != outstanding_frames_.end()) {
+ *child = it->second;
+ } else {
+ // Need to create a new frame instance.
+ VLOG(2) << "Create frame: " << child_name;
+
+ FrameState* temp = new FrameState;
+ temp->frame_name = child_name;
+ temp->frame_id = Hash64(child_name);
+ temp->parent_frame = frame;
+ temp->parent_iter = iter;
+ s = GetNodeAttr(node->def(), "parallel_iterations",
+ &temp->max_parallel_iterations);
+ CHECK(s.ok()) << s;
+ // 'iterations' is a fixed-length circular buffer.
+ temp->iterations.resize(temp->max_parallel_iterations + 1);
+ IterationState* iter_state = new IterationState;
+ temp->iterations[0] = iter_state;
+
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ InitializePending(impl_->graph_, iter_state->pending_count);
+ iter_state->dead_count =
+ new std::vector<int>(impl_->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ auto frame_pending = impl_->frame_input_count_.find(enter_name);
+ DCHECK(frame_pending != impl_->frame_input_count_.end());
+ temp->num_pending_inputs = frame_pending->second;
+ temp->num_outstanding_iterations = 1;
+ *child = temp;
+
+ frame->GetIteration(iter)->outstanding_frame_count++;
+ outstanding_frames_[child_name] = temp;
+ }
+}
+
+void ExecutorState::IncrementIteration(FrameState* frame,
+ TaggedNodeSeq* ready) {
+ frame->iteration_count++;
+ int64 next_iter = frame->iteration_count;
+
+ VLOG(2) << "Create iteration: [" << frame->frame_name << ", " << next_iter
+ << "]";
+
+ IterationState* iter_state = new IterationState;
+ frame->SetIteration(next_iter, iter_state);
+ frame->num_outstanding_iterations++;
+ frame->dead_exits.clear();
+
+ iter_state->outstanding_ops = 0;
+ iter_state->outstanding_frame_count = 0;
+ iter_state->pending_count = new std::vector<int>;
+ InitializePending(impl_->graph_, iter_state->pending_count);
+ iter_state->dead_count = new std::vector<int>(impl_->graph_->num_node_ids());
+ iter_state->input_tensors = new std::vector<Entry>(impl_->total_tensors_);
+
+ // Activate the successors of the deferred roots in the new iteration.
+ ActivateNexts(frame, next_iter, ready);
+
+ // Activate the loop invariants in the new iteration.
+ ActivateLoopInvs(frame, next_iter, ready);
+}
+
+bool ExecutorState::SetOutputFrameIter(const TaggedNode& tagged_node,
+ const EntryVector& outputs,
+ FrameState** output_frame,
+ int64* output_iter,
+ TaggedNodeSeq* ready) {
+ const Node* node = tagged_node.node;
+ FrameState* input_frame = tagged_node.input_frame;
+ int64 input_iter = tagged_node.input_iter;
+ bool is_dead = tagged_node.is_dead;
+ bool is_enter = IsEnter(node);
+
+ if (is_enter) {
+ FindOrCreateChildFrame(input_frame, input_iter, node, output_frame);
+ // Propagate if this is a loop invariant.
+ bool is_constant;
+ Status s = GetNodeAttr(node->def(), "is_constant", &is_constant);
+ CHECK(s.ok()) << s;
+ if (is_constant) {
+ AddLoopInv(*output_frame, node, outputs[0], ready);
+ }
+ --(*output_frame)->num_pending_inputs;
+ *output_iter = 0;
+ } else if (IsExit(node)) {
+ if (is_dead) {
+ // Stop and remember this node if it is a dead exit.
+ if (input_iter == input_frame->iteration_count) {
+ input_frame->dead_exits.push_back(node);
+ }
+ *output_frame = nullptr;
+ } else {
+ *output_frame = input_frame->parent_frame;
+ *output_iter = input_frame->parent_iter;
+ }
+ } else if (IsNextIteration(node)) {
+ if (is_dead) {
+ // Stop the deadness propagation
+ *output_frame = nullptr;
+ } else {
+ if (input_iter == input_frame->iteration_count &&
+ input_frame->num_outstanding_iterations ==
+ input_frame->max_parallel_iterations) {
+ // Reached the maximum for parallel iterations.
+ input_frame->next_iter_roots.push_back({node, outputs[0]});
+ *output_frame = nullptr;
+ } else {
+ // If this is a new iteration, start it.
+ if (input_iter == input_frame->iteration_count) {
+ IncrementIteration(input_frame, ready);
+ }
+ *output_iter = input_iter + 1;
+ }
+ }
+ }
+ return is_enter;
+}
+
+void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
+ TaggedNodeSeq* ready) {
+ int64 curr_iter = iter;
+ while (curr_iter <= frame->iteration_count &&
+ IsIterationDone(frame, curr_iter)) {
+ // Delete the iteration curr_iter
+ VLOG(2) << "Delete iteration [" << frame->frame_name << ", " << curr_iter
+ << "].";
+
+ delete frame->GetIteration(curr_iter);
+ frame->SetIteration(curr_iter, nullptr);
+ --frame->num_outstanding_iterations;
+ ++curr_iter;
+
+ // If there is a deferred iteration, start it.
+ if (frame->next_iter_roots.size() > 0) {
+ IncrementIteration(frame, ready);
+ }
+ }
+
+ if (IsFrameDone(frame)) {
+ FrameState* parent_frame = frame->parent_frame;
+ int64 parent_iter = frame->parent_iter;
+
+ // Propagate all the dead exits to the parent frame.
+ for (const Node* node : frame->dead_exits) {
+ auto parent_iter_state = parent_frame->GetIteration(parent_iter);
+ std::vector<int>* pending = parent_iter_state->pending_count;
+ std::vector<int>* dead_count = parent_iter_state->dead_count;
+ for (const Edge* e : node->out_edges()) {
+ const Node* dst_node = e->dst();
+ const int dst_id = dst_node->id();
+
+ bool dst_dead = true;
+ bool dst_ready = false;
+ // We know this is a dead input to dst
+ if (IsMerge(dst_node)) {
+ if (e->IsControlEdge()) {
+ (*pending)[dst_id] -= 2;
+ int count = (*pending)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = (count == 1) || ((count == 0) && dst_dead);
+ } else {
+ ++(*dead_count)[dst_id];
+ dst_dead = ((*dead_count)[dst_id] == dst_node->num_inputs());
+ dst_ready = ((*pending)[dst_id] == 0) && dst_dead;
+ }
+ } else {
+ ++(*dead_count)[dst_id];
+ dst_ready = (--(*pending)[dst_id] == 0);
+ }
+ if (dst_ready) {
+ ready->push_back(
+ TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
+ parent_iter_state->outstanding_ops++;
+ }
+ }
+ }
+
+ // Delete the frame
+ const string& frame_name = frame->frame_name;
+ VLOG(2) << "Delete frame " << frame_name;
+ outstanding_frames_.erase(frame_name);
+ delete frame;
+
+ // Cleanup recursively
+ if (parent_frame != nullptr) {
+ parent_frame->GetIteration(parent_iter)->outstanding_frame_count--;
+ CleanupFramesIterations(parent_frame, parent_iter, ready);
+ }
+ }
+}
+
+// When the ExecutorImpl graph has no control-flow nodes,
+// SimpleExecutorState is used instead of ExecutorState. It maintains
+// less internal state and is convenient for experimenting with async
+// op kernels.
+class SimpleExecutorState {
+ public:
+ SimpleExecutorState(const Executor::Args& args, ExecutorImpl* impl);
+ ~SimpleExecutorState() {
+ for (auto it : device_context_map_) {
+ it.second->Unref();
+ }
+ delete slice_reader_cache_;
+ }
+ void RunAsync(Executor::DoneCallback done);
+
+ private:
+ typedef SimpleExecutorState ME;
+
+ // Not owned.
+ Rendezvous* rendezvous_;
+ StepStatsCollector* stats_collector_;
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
+ FunctionCallFrame* call_frame_;
+ const ExecutorImpl* impl_;
+ CancellationManager* cancellation_manager_;
+ Executor::Args::Runner runner_;
+
+ // Owned.
+
+ // The i-th node's j-th input is in
+ // input_tensors_[impl_->nodes[i].input_start + j]. An entry is either
+ // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
+ //
+ // NOTE: Not protected by mu_ because input_tensors_ is resized once.
+ // Each element is written once by the source node of an edge and is
+ // cleared by the destination of the same edge. The latter node is
+ // never run concurrently with the former node.
+ struct Entry {
+ Tensor val = *kEmptyTensor; // A tensor value.
+ Tensor* ref = nullptr; // A tensor reference.
+ mutex* ref_mu = nullptr; // mutex for *ref if ref is not nullptr.
+
+ // Every entry carries an optional DeviceContext containing
+ // Device-specific information about how the Tensor was produced.
+ DeviceContext* device_context = nullptr;
+
+ // The attributes of the allocator that creates the tensor.
+ AllocatorAttributes alloc_attr;
+ };
+
+ // Contains a map from node id to the DeviceContext object that was
+ // assigned by the device at the beginning of a step.
+ DeviceContextMap device_context_map_;
+
+ std::vector<Entry> input_tensors_;
+
+ // Step-local resource manager.
+ ResourceMgr step_resource_manager_;
+
+ // Invoked when the execution finishes.
+ Executor::DoneCallback done_cb_;
+
+ // How many active threads of computation are being used. Same as
+ // the number of pending Process() functions.
+ std::atomic_int_fast32_t num_active_;
+
+ mutex mu_;
+ Status status_ GUARDED_BY(mu_);
+
+ // The i-th kernel is still waiting for pending_[i] inputs.
+ class CountDown {
+ public:
+ CountDown() : v_(0) {}
+ void Set(int32 v) { v_.store(v); }
+ bool Dec() {
+ return v_.load(std::memory_order_acquire) == 1 || v_.fetch_sub(1) == 1;
+ }
+
+ private:
+ std::atomic_int_fast32_t v_;
+ };
+ std::vector<CountDown> pending_;
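+
+ // For illustration (the count is hypothetical): after pending_[id].Set(3),
+ // the first two calls to pending_[id].Dec() return false and the third
+ // returns true, at which point node "id" has received all of its inputs
+ // and ProcessOutputs pushes it onto the ready queue.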
+
+ // Process Node identified by "id" in current thread. "scheduled_usec"
+ // indicates when the node becomes ready and gets scheduled.
+ void Process(int id, int64 scheduled_usec);
+
+ // Before invoking item->kernel, fills in its "inputs".
+ Status PrepareInputs(const NodeItem& item, TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts);
+
+ // After item->kernel computation is done, processes its outputs
+ // and returns nodes that become "ready".
+ typedef gtl::InlinedVector<int, 8> ReadyNodeIds;
+ Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
+ ReadyNodeIds* ready, NodeExecStats* stats);
+
+ // "node" just finishes. Takes ownership of "stats". Returns true if
+ // execution has completed.
+ bool NodeDone(const Status& s, const Node* node, const ReadyNodeIds& ready,
+ NodeExecStats* stats, std::deque<int>* inline_ready);
+
+ // Call Process() on all nodes in 'inline_ready'.
+ void ProcessInline(const std::deque<int>& inline_ready);
+
+ // Schedule all the expensive nodes in 'ready', and put all the inexpensive
+ // nodes in 'ready' into 'inline_ready'.
+ void ScheduleReady(const ReadyNodeIds& ready, std::deque<int>* inline_ready);
+
+ // One thread of control finishes.
+ void Finish();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SimpleExecutorState);
+};
+
+SimpleExecutorState::SimpleExecutorState(const Executor::Args& args,
+ ExecutorImpl* impl)
+ : rendezvous_(args.rendezvous),
+ stats_collector_(args.stats_collector),
+ slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
+ call_frame_(args.call_frame),
+ impl_(impl),
+ cancellation_manager_(args.cancellation_manager),
+ runner_(args.runner),
+ num_active_(0),
+ pending_(impl_->nodes_.size()) {}
+
+void SimpleExecutorState::ProcessInline(const std::deque<int>& inline_ready) {
+ if (inline_ready.empty()) return;
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ for (int id : inline_ready) {
+ Process(id, scheduled_usec);
+ }
+}
+
+void SimpleExecutorState::ScheduleReady(const ReadyNodeIds& ready,
+ std::deque<int>* inline_ready) {
+ if (ready.empty()) return;
+
+ int64 scheduled_usec = 0;
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ if (inline_ready == nullptr) {
+ // Schedule to run all the ready ops in thread pool.
+ for (auto id : ready) {
+ runner_(std::bind(&ME::Process, this, id, scheduled_usec));
+ }
+ return;
+ }
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ int curr_expensive_node = -1;
+ for (auto id : ready) {
+ if (!nodes[id].kernel->IsExpensive()) {
+ // Inline this inexpensive node.
+ inline_ready->push_back(id);
+ } else {
+ if (curr_expensive_node != -1) {
+ // Dispatch to another thread since there is plenty of work to
+ // do for this thread.
+ runner_(
+ std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec));
+ }
+ curr_expensive_node = id;
+ }
+ }
+ if (curr_expensive_node != -1) {
+ if (inline_ready->empty()) {
+ // Tail recursion optimization
+ inline_ready->push_back(curr_expensive_node);
+ } else {
+ // There are inline nodes to run already. We dispatch this expensive
+ // node to another thread.
+ runner_(
+ std::bind(&ME::Process, this, curr_expensive_node, scheduled_usec));
+ }
+ }
+}
+
+void SimpleExecutorState::RunAsync(Executor::DoneCallback done) {
+ const Graph* graph = impl_->graph_;
+ ReadyNodeIds ready;
+
+ // Ask the device to fill in the device context map.
+ Device* device = impl_->params_.device;
+ device->FillContextMap(graph, &device_context_map_);
+
+ for (const Node* n : graph->nodes()) {
+ const int id = n->id();
+ const int num_in_edges = n->in_edges().size();
+ pending_[id].Set(num_in_edges);
+ if (num_in_edges == 0) {
+ ready.push_back(id);
+ }
+ }
+ if (ready.empty()) {
+ done(Status::OK());
+ } else {
+ num_active_ = ready.size();
+ done_cb_ = done;
+ input_tensors_.resize(impl_->total_tensors_);
+ // Schedule to run all the ready ops in thread pool.
+ ScheduleReady(ready, nullptr);
+ }
+}
+
+Status SimpleExecutorState::PrepareInputs(
+ const NodeItem& item, TensorValueVec* inputs,
+ DeviceContextVec* input_device_contexts) {
+ const Node* node = item.node;
+
+ inputs->clear();
+ inputs->resize(node->num_inputs());
+ input_device_contexts->clear();
+ input_device_contexts->resize(node->num_inputs());
+
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ const bool expect_ref = IsRefType(node->input_type(i));
+ Entry* entry = input_tensors_.data() + item.input_start + i;
+ (*input_device_contexts)[i] = entry->device_context;
+
+ // i-th input.
+ TensorValue* inp = &(*inputs)[i];
+
+ if (entry->ref == nullptr) {
+ if (expect_ref) {
+ return AttachDef(
+ errors::InvalidArgument(i, "-th input expects a ref type"),
+ item.kernel->def());
+ }
+ inp->tensor = &entry->val;
+ } else {
+ if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
+ return AttachDef(
+ errors::FailedPrecondition("Attempting to use uninitialized value ",
+ item.kernel->def().input(i)),
+ item.kernel->def());
+ }
+ if (expect_ref) {
+ inp->mutex_if_ref = entry->ref_mu;
+ inp->tensor = entry->ref;
+ } else {
+ // Automatically deref the tensor ref when the op expects a
+ // tensor but is given a ref to a tensor. Need to deref it
+ // under the mutex.
+ {
+ mutex_lock l(*(entry->ref_mu));
+ entry->val = *entry->ref;
+ }
+ inp->tensor = &entry->val;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void SimpleExecutorState::Process(int id, int64 scheduled_usec) {
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ ReadyNodeIds ready;
+ std::deque<int> inline_ready;
+
+ // Parameters passed to OpKernel::Compute.
+ TensorValueVec inputs;
+ DeviceContextVec input_device_contexts;
+
+ OpKernelContext::Params params;
+ Device* device = impl_->params_.device;
+ params.device = device;
+ // track allocations if and only if we are collecting statistics
+ params.track_allocations = (stats_collector_ != nullptr);
+ params.rendezvous = rendezvous_;
+ params.cancellation_manager = cancellation_manager_;
+ params.call_frame = call_frame_;
+ params.function_library = impl_->params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_resource_manager = &step_resource_manager_;
+ params.slice_reader_cache = slice_reader_cache_;
+ params.inputs = &inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.frame_iter = FrameAndIter(0, 0);
+
+ Status s;
+ NodeExecStats* stats = nullptr;
+ bool completed = false;
+ inline_ready.push_back(id);
+ while (!inline_ready.empty()) {
+ id = inline_ready.front();
+ inline_ready.pop_front();
+ const NodeItem& item = nodes[id];
+ const Node* node = item.node;
+
+ // Set the device_context for this node id, if it exists.
+ auto dc_it = device_context_map_.find(id);
+ if (dc_it != device_context_map_.end()) {
+ params.op_device_context = dc_it->second;
+ }
+
+ if (stats_collector_) {
+ stats = new NodeExecStats;
+ stats->set_node_name(node->name());
+ nodestats::SetScheduled(stats, scheduled_usec);
+ nodestats::SetAllStart(stats);
+ }
+
+ VLOG(1) << "Process node: " << id << " " << SummarizeNodeDef(node->def());
+
+ // Prepares inputs.
+ s = PrepareInputs(item, &inputs, &input_device_contexts);
+ if (!s.ok()) {
+ // Continue to process the nodes in 'inline_ready'.
+ completed = NodeDone(s, item.node, ready, stats, &inline_ready);
+ continue;
+ }
+
+ OpKernel* op_kernel = item.kernel;
+ params.op_kernel = op_kernel;
+ params.output_alloc_attr = [this, node, op_kernel](int index) {
+ return OutputAttributes(&impl_->alloc_attr_, node, op_kernel, index);
+ };
+
+ // Asynchronous computes.
+ AsyncOpKernel* async = op_kernel->AsAsync();
+ if (async) {
+ auto pcopy = CopyParams(params);
+ auto ctx = new OpKernelContext(*pcopy);
+ auto done = [this, item, ctx, stats, pcopy]() {
+ VLOG(2) << this
+ << " Async kernel done: " << SummarizeNodeDef(item.node->def());
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+ ReadyNodeIds ready;
+ Status s = ProcessOutputs(item, ctx, &ready, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, ctx);
+ // Schedule to run all the ready ops in thread pool.
+ bool completed = NodeDone(s, item.node, ready, stats, nullptr);
+ delete ctx;
+ DeleteParams(pcopy);
+ if (completed) Finish();
+ };
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->ComputeAsync(async, ctx, done);
+ } else {
+ // Synchronous computes.
+ OpKernelContext ctx(params);
+ if (stats_collector_) nodestats::SetOpStart(stats);
+ device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+ if (stats_collector_) nodestats::SetOpEnd(stats);
+
+ s = ProcessOutputs(item, &ctx, &ready, stats);
+ if (stats_collector_) nodestats::SetMemory(stats, &ctx);
+ if (stats_collector_) {
+ scheduled_usec = nodestats::NowInUsec();
+ }
+ completed = NodeDone(s, node, ready, stats, &inline_ready);
+ }
+ } // while !inline_ready.empty()
+
+ // This thread of computation is done if completed = true.
+ if (completed) Finish();
+}
+
+bool SimpleExecutorState::NodeDone(const Status& s, const Node* node,
+ const ReadyNodeIds& ready,
+ NodeExecStats* stats,
+ std::deque<int>* inline_ready) {
+ if (stats_collector_) {
+ nodestats::SetAllEnd(stats);
+ if (!SetTimelineLabel(node, stats)) {
+ // Only record non-transfer nodes.
+ stats_collector_->Save(impl_->params_.device->name(), stats);
+ } else {
+ delete stats;
+ }
+ }
+
+ Rendezvous* captured_rendezvous = nullptr; // Will be set on error.
+ if (!s.ok()) {
+ // Some error happened. This thread of computation is done.
+ mutex_lock l(mu_);
+ if (status_.ok()) {
+ captured_rendezvous = rendezvous_;
+ if (captured_rendezvous) captured_rendezvous->Ref();
+ status_ = s;
+ }
+ }
+ if (captured_rendezvous) {
+ // If we captured the rendezvous_ pointer, we are in an error condition.
+ // Use captured_rendezvous, in case "this" is deleted by another thread.
+ TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
+ captured_rendezvous->StartAbort(s);
+ captured_rendezvous->Unref();
+ }
+
+ bool completed = false;
+ int ready_size = ready.size();
+ if (ready_size == 0 || !s.ok()) {
+ completed = (num_active_.fetch_sub(1) == 1);
+ } else if (ready_size > 1) {
+ num_active_.fetch_add(ready_size - 1, std::memory_order_relaxed);
+ }
+
+ // Schedule the ready nodes in 'ready'.
+ if (s.ok()) {
+ ScheduleReady(ready, inline_ready);
+ }
+ return completed;
+}
+
+void SimpleExecutorState::Finish() {
+ mu_.lock();
+ auto ret = status_;
+ auto done_cb = done_cb_;
+ auto runner = runner_;
+ mu_.unlock();
+ delete this;
+ CHECK(done_cb != nullptr);
+ runner([done_cb, ret]() { done_cb(ret); });
+}
+
+Status SimpleExecutorState::ProcessOutputs(const NodeItem& item,
+ OpKernelContext* ctx,
+ ReadyNodeIds* ready,
+ NodeExecStats* stats) {
+ Status s = ctx->status();
+ if (!s.ok()) {
+ s = AttachDef(s, item.kernel->def());
+ LOG(WARNING) << this << " Compute status: " << s;
+ return s;
+ }
+
+ // Processes outputs.
+ gtl::InlinedVector<Entry, 4> outputs;
+ const Node* node = item.node;
+ outputs.resize(node->num_outputs());
+
+ // Get the device_context for this node id, if it exists.
+ DeviceContext* device_context = nullptr;
+ auto dc_it = device_context_map_.find(node->id());
+ if (dc_it != device_context_map_.end()) {
+ device_context = dc_it->second;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ TensorValue val = ctx->release_output(i);
+ // Sanity check of output tensor types.
+ DataType dtype = val->dtype();
+ if (val.is_ref()) dtype = MakeRefType(dtype);
+ if (dtype == node->output_type(i)) {
+ Entry* out = &(outputs[i]);
+ if (val.is_ref()) {
+ out->ref = val.tensor;
+ out->ref_mu = val.mutex_if_ref;
+ } else {
+ out->val = *val.tensor;
+ }
+
+ // Set the device context of the output entry.
+ out->device_context = device_context;
+
+ // Set the allocator attributes of the output entry.
+ out->alloc_attr = ctx->output_alloc_attr(i);
+
+ if (stats_collector_ && val.tensor->IsInitialized()) {
+ nodestats::SetOutput(stats, i, ctx->output_allocation_type(i),
+ val.tensor);
+ }
+ } else {
+ s.Update(
+ errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
+ " does not match declared output type ",
+ DataTypeString(node->output_type(i)),
+ " for operation ", SummarizeNodeDef(node->def())));
+ }
+ if (!val.is_ref()) {
+ // If OpKernelContext returns outputs via pass-by-value, we
+ // don't need this trouble.
+ delete val.tensor;
+ }
+ }
+ if (!s.ok()) return s;
+
+ // Clears inputs.
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ input_tensors_[item.input_start + i].val = *kEmptyTensor;
+ }
+
+ // Propagates outputs along out edges.
+ ready->clear();
+ const std::vector<NodeItem>& nodes = impl_->nodes_;
+ for (const Edge* e : node->out_edges()) {
+ const int src_slot = e->src_output();
+ const int dst_id = e->dst()->id();
+ const NodeItem& dst_item = nodes[dst_id];
+ if (!e->IsControlEdge()) {
+ const int dst_slot = e->dst_input();
+ input_tensors_[dst_item.input_start + dst_slot] = outputs[src_slot];
+ }
+ if (pending_[dst_id].Dec()) {
+ ready->push_back(dst_id);
+ }
+ }
+ return Status::OK();
+}
+
+// NOTE(yuanbyu): Use the executor that supports control flow by default.
+const bool use_control_flow_executor = true;
+void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
+ if (params_.has_control_flow || use_control_flow_executor) {
+ (new ExecutorState(args, this))->RunAsync(done);
+ } else {
+ (new SimpleExecutorState(args, this))->RunAsync(done);
+ }
+}
+
+} // end namespace
+
+Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
+ Executor** executor) {
+ ExecutorImpl* impl = new ExecutorImpl(params, graph);
+ Status s = impl->Initialize();
+ if (s.ok()) {
+ *executor = impl;
+ } else {
+ delete impl;
+ }
+ return s;
+}
+
+Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
+ const NodeDef& ndef, OpKernel** kernel) {
+ auto device_type = DeviceType(device->attributes().device_type());
+ auto allocator = device->GetAllocator(AllocatorAttributes());
+ return CreateOpKernel(device_type, device, allocator, flib, ndef, kernel);
+}
+
+void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
+
+Status CreateCachedKernel(Device* device, const string& session,
+ FunctionLibraryRuntime* flib, const NodeDef& ndef,
+ OpKernel** kernel) {
+ auto op_seg = device->op_segment();
+ auto create_fn = [device, flib, &ndef](OpKernel** kernel) {
+ return CreateNonCachedKernel(device, flib, ndef, kernel);
+ };
+ return op_seg->FindOrCreate(session, ndef.name(), kernel, create_fn);
+}
+
+// Deletes "kernel".
+void DeleteCachedKernel(Device* device, const string& session,
+ OpKernel* kernel) {
+ // Do nothing.
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
new file mode 100644
index 0000000000..82bcbab836
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor.h
@@ -0,0 +1,209 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class StepStatsCollector;
+
+// Executor runs a graph computation.
+// Example:
+// Graph* graph = ...;
+// ... construct graph ...
+// Executor* executor;
+// TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
+// Rendezvous* rendezvous = NewNaiveRendezvous();
+// TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
+// TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
+// TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
+// ... ...
+//
+// Multiple threads can call Executor::Run concurrently.
+class Executor {
+ public:
+ virtual ~Executor() {}
+
+ // RunAsync() executes the graph computation. "done" is run when the
+ // graph computation completes. If any error happens during the
+ // computation, "done" is run and the error is passed to "done".
+ //
+ // RunAsync() is given a few arguments in Args. The caller must
+ // ensure objects passed in Args (rendezvous, stats_collector, etc.)
+ // are alive at least until done is invoked. All pointers to the
+ // argument objects can be nullptr.
+ //
+ // RunAsync() uses the given "rendezvous", if not null, as the
+ // mechanism to communicate inputs and outputs of the underlying
+ // graph computation.
+ //
+ // RunAsync() calls "stats_collector", if not null, to keep track of
+ // stats. This allows us to collect statistics and traces on demand.
+ //
+  // RunAsync() is given a "call_frame", if the executor is used for
+  // executing a function, to pass arguments and return values
+  // between the caller and the callee.
+ //
+ // RunAsync() uses "cancellation_manager", if not nullptr, to
+ // register callbacks that should be called if the graph computation
+ // is cancelled. Note that the callbacks merely unblock any
+ // long-running computation, and a cancelled step will terminate by
+ // returning/calling the DoneCallback as usual.
+ //
+ // RunAsync() dispatches closures to "runner". Typically, "runner"
+  // is backed by a bounded threadpool.
+ struct Args {
+ Rendezvous* rendezvous = nullptr;
+ StepStatsCollector* stats_collector = nullptr;
+ FunctionCallFrame* call_frame = nullptr;
+ CancellationManager* cancellation_manager = nullptr;
+
+ typedef std::function<void()> Closure;
+ typedef std::function<void(Closure)> Runner;
+ Runner runner = nullptr;
+ };
+ typedef std::function<void(const Status&)> DoneCallback;
+ virtual void RunAsync(const Args& args, DoneCallback done) = 0;
+
+ // Synchronous wrapper for RunAsync().
+ Status Run(const Args& args) {
+ Status ret;
+ Notification n;
+ RunAsync(args, [&ret, &n](const Status& s) {
+ ret = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+ }
+};
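+
+// A minimal sketch of filling in Args, assuming a caller-owned thread pool
+// "pool" and a rendezvous "rendezvous" (both names are placeholders):
+//
+//   Executor::Args args;
+//   args.rendezvous = rendezvous;
+//   args.runner = [pool](Executor::Args::Closure c) {
+//     pool->Schedule(std::move(c));
+//   };
+//   Status s = executor->Run(args);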
+
+// Creates an Executor that computes the given "graph".
+//
+// If successful, returns the constructed executor in "*executor". The
+// caller keeps the ownership of "device". The returned executor takes
+// the ownership of "graph". Otherwise, returns an error status.
+//
+// "params" provides a set of context for the executor. We expect that
+// different contexts would provide different implementations.
+struct LocalExecutorParams {
+ Device* device;
+
+ // The library runtime support.
+ FunctionLibraryRuntime* function_library;
+
+ // True iff the computation contains control flow nodes.
+ bool has_control_flow;
+
+ // create_kernel returns an instance of op kernel based on NodeDef.
+ // delete_kernel is called for every kernel used by the executor
+ // when the executor is deleted.
+ std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
+ std::function<void(OpKernel*)> delete_kernel;
+};
+::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
+ const Graph* graph, Executor** executor);
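+
+// For illustration, "params" might be populated roughly as follows, assuming
+// a device "device" and a function library runtime "flib" are already
+// available (a sketch, not the only valid configuration):
+//
+//   LocalExecutorParams params;
+//   params.device = device;
+//   params.function_library = flib;
+//   params.has_control_flow = false;
+//   params.create_kernel = [device, flib](const NodeDef& ndef,
+//                                         OpKernel** kernel) {
+//     return CreateNonCachedKernel(device, flib, ndef, kernel);
+//   };
+//   params.delete_kernel = DeleteNonCachedKernel;
+//   Executor* executor;
+//   TF_CHECK_OK(NewLocalExecutor(params, graph, &executor));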
+
+// A class to help run multiple executors in parallel and wait until
+// all of them are complete.
+//
+// ExecutorBarrier deletes itself after the function returned by Get()
+// is called.
+class ExecutorBarrier {
+ public:
+ typedef std::function<void(const Status&)> StatusCallback;
+
+ // Create an ExecutorBarrier for 'num' different executors.
+ //
+ // 'r' is the shared Rendezvous object that is used to communicate
+ // state. If any of the executors experiences an error, the
+ // rendezvous object will be aborted exactly once.
+ //
+ // 'done' is called after the last executor completes, and
+ // ExecutorBarrier is deleted.
+ ExecutorBarrier(int num, Rendezvous* r, StatusCallback done)
+ : rendez_(r), done_cb_(done), pending_(num) {}
+
+ ~ExecutorBarrier() {}
+
+ // Returns a closure that Executors must call when they are done
+ // computing, passing the status of their execution as an argument.
+ StatusCallback Get() {
+ return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
+ }
+
+ private:
+ Rendezvous* rendez_ = nullptr;
+ StatusCallback done_cb_ = nullptr;
+
+ mutable mutex mu_;
+ int pending_ GUARDED_BY(mu_) = 0;
+ Status status_ GUARDED_BY(mu_);
+
+ void WhenDone(const Status& s) {
+ bool error = false;
+ StatusCallback done = nullptr;
+ Status status;
+ {
+ mutex_lock l(mu_);
+ // If we are the first error encountered, mark the status
+ // appropriately and later trigger an abort of the Rendezvous
+ // object by this thread only.
+ if (status_.ok() && !s.ok()) {
+ error = true;
+ status_ = s;
+ }
+
+ // If this is the last call to WhenDone, call the final callback
+ // below.
+ if (--pending_ == 0) {
+ CHECK(done_cb_ != nullptr);
+ done = done_cb_;
+ done_cb_ = nullptr;
+ }
+ status = status_;
+ }
+ if (error) {
+ rendez_->StartAbort(status);
+ }
+ if (done != nullptr) {
+ delete this;
+ done(status);
+ }
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
+};
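+
+// Illustrative use, assuming a vector of executors "execs", a shared
+// rendezvous "rendez", and a Notification "n" (placeholder names):
+//
+//   ExecutorBarrier* barrier = new ExecutorBarrier(
+//       execs.size(), rendez, [&n](const Status& s) {
+//         TF_CHECK_OK(s);
+//         n.Notify();
+//       });
+//   for (Executor* e : execs) {
+//     e->RunAsync(args, barrier->Get());
+//   }
+//   n.WaitForNotification();  // The barrier has deleted itself by now.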
+
+// A few helpers to facilitate create/delete kernels.
+
+// Creates a kernel based on "ndef" on device "device". The kernel can
+// access the functions in the "flib". The caller takes ownership of
+// returned "*kernel".
+Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
+ const NodeDef& ndef, OpKernel** kernel);
+
+// Deletes "kernel" returned by CreateKernel.
+void DeleteNonCachedKernel(OpKernel* kernel);
+
+// Creates a kernel based on "ndef" on device "device". The kernel can
+// access the functions in the "flib". The caller does not take
+// ownership of returned "*kernel". If a kernel has been created for
+// ndef.name(), returns the same kernel instance.
+Status CreateCachedKernel(Device* device, const string& session,
+ FunctionLibraryRuntime* flib, const NodeDef& ndef,
+ OpKernel** kernel);
+
+// Deletes "kernel" returned by CreateCachedKernel.
+void DeleteCachedKernel(Device* device, const string& session,
+ OpKernel* kernel);
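+
+// For illustration, with a hypothetical session name "sess" (a sketch only):
+//
+//   OpKernel* kernel = nullptr;
+//   TF_CHECK_OK(CreateCachedKernel(device, "sess", flib, ndef, &kernel));
+//   ... use "kernel"; the device's OpSegment keeps it alive for "sess" ...
+//   DeleteCachedKernel(device, "sess", kernel);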
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
new file mode 100644
index 0000000000..2b1a041235
--- /dev/null
+++ b/tensorflow/core/common_runtime/function.cc
@@ -0,0 +1,1335 @@
+#include "tensorflow/core/common_runtime/function.h"
+
+#include <deque>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+// A few string constants used throughout this module.
+static const char* const kArgOp = "_Arg";
+static const char* const kRetOp = "_Retval";
+static const char* const kGradientOp = "SymbolicGradient";
+static const char* const kNodeLabel = "Func";
+
+// Represents the index-th output of a node.
+struct Endpoint {
+ Node* node;
+ int index;
+
+  // Returns the string name that represents this endpoint.
+ string name() const {
+ if (index == 0) {
+ return node->name();
+ } else {
+ return strings::StrCat(node->name(), ":", index);
+ }
+ }
+
+ DataType dtype() const { return node->output_type(index); }
+};
+
+struct EndpointHash {
+ uint64 operator()(const Endpoint& x) const {
+ return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
+ x.index);
+ }
+};
+
+struct EndpointEq {
+ bool operator()(const Endpoint& x, const Endpoint& y) const {
+ return (x.node == y.node) && (x.index == y.index);
+ }
+};
+
+// The following Add* routines are used to add a few graph nodes while
+// functions are transformed.
+static Node* AddNoOp(Graph* g) {
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("NoOp");
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+static Node* AddIdentity(Graph* g, Endpoint input) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("Identity");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddArg(Graph* g, DataType dtype, int index) {
+ DCHECK_LT(0, dtype);
+ DCHECK_LT(dtype, DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kArgOp);
+ AddNodeAttr("T", dtype, &ndef);
+ AddNodeAttr("index", index, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+static Node* AddRet(Graph* g, Endpoint input, int index) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kRetOp);
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ AddNodeAttr("index", index, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddZerosLike(Graph* g, Endpoint input) {
+ DCHECK_LT(0, input.dtype());
+ DCHECK_LT(input.dtype(), DT_FLOAT_REF);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+}
+
+static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<Endpoint> grads) {
+ const int num_x = n->num_inputs();
+ const int num_y = n->num_outputs();
+ CHECK_EQ(num_y, grads.size());
+
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op(kGradientOp);
+
+ // The gradient node should have num_x + num_y inputs.
+ std::vector<Endpoint> n_inputs(num_x);
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ n_inputs[e->dst_input()] = {e->src(), e->src_output()};
+ }
+ DataTypeVector in_types;
+ for (const Endpoint& ep : n_inputs) {
+ ndef.add_input(ep.name());
+ in_types.push_back(ep.dtype());
+ }
+ for (const Endpoint& ep : grads) {
+ ndef.add_input(ep.name());
+ in_types.push_back(ep.dtype());
+ }
+ CHECK_EQ(ndef.input_size(), num_x + num_y);
+
+ AddNodeAttr("Tin", in_types, &ndef);
+
+ // The gradient node's outputs have the same types as the node 'n's
+ // inputs.
+ AddNodeAttr("Tout", n->input_types(), &ndef);
+ NameAttrList func;
+ func.set_name(n->type_string());
+ *(func.mutable_attr()) = n->def().attr();
+ AddNodeAttr("f", func, &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ return ret;
+}
+
+class ArgOp : public OpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ Tensor val;
+ OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ ctx->set_output(0, val);
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_GPU), ArgOp);
+
+class RetvalOp : public OpKernel {
+ public:
+ explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& val = ctx->input(0);
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_GPU), RetvalOp);
+
+static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
+
+class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
+ public:
+ FunctionLibraryRuntimeImpl(Device* device, Runner runner,
+ const FunctionLibraryDefinition* lib_def);
+
+ ~FunctionLibraryRuntimeImpl() override;
+
+ Status Instantiate(const string& function_name,
+ const InstantiateAttrValueMap& attrs,
+ Handle* handle) override;
+
+ const FunctionBody* GetFunctionBody(Handle handle) override;
+
+ Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
+
+ void Run(const Options& opts, Handle handle, gtl::ArraySlice<Tensor> args,
+ std::vector<Tensor>* rets, DoneCallback done) override;
+
+ bool IsDefined(const string& function_name) override;
+
+ private:
+ typedef FunctionLibraryRuntimeImpl ME;
+
+ Device* const device_;
+ Runner runner_ = nullptr;
+ const FunctionLibraryDefinition* const lib_def_;
+ std::function<Status(const string&, const OpDef**)> get_func_sig_;
+ std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
+
+ mutable mutex mu_;
+
+ // Maps function instantiation to a handle. The key is a
+ // canonicalized representation of the function name and
+ // instantiation attrs. The handle is an index into the items_.
+ std::unordered_map<string, Handle> table_ GUARDED_BY(mu_);
+
+ // func_graphs_ never shrinks or reorders its members.
+ std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_);
+
+ // The instantiated and transformed function is encoded as a Graph
+ // object, and an executor is created for the graph.
+ struct Item : public core::RefCounted {
+ Executor* exec = nullptr;
+
+ ~Item() override { delete this->exec; }
+ };
+ std::vector<Item*> items_;
+
+ Status FunctionDefToBody(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attrs,
+ FunctionBody** fbody);
+ Status CreateItem(Handle handle, Item** item);
+ Status GetOrCreateItem(Handle handle, Item** item);
+ Status InstantiateSymbolicGradient(const InstantiateAttrValueMap& attrs,
+ FunctionBody** g_body);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
+};
+
+FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def)
+ : device_(device), runner_(runner), lib_def_(lib_def) {
+ get_func_sig_ = [this](const string& op, const OpDef** sig) {
+ Status s;
+ *sig = lib_def_->LookUp(op, &s);
+ return s;
+ };
+ create_kernel_ = [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateKernel(ndef, kernel);
+ };
+}
+
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
+ for (FunctionBody* p : func_graphs_) delete p;
+ for (Item* item : items_)
+ if (item) item->Unref();
+}
+
+// An asynchronous op kernel which executes an instantiated function
+// defined in a library.
+class CallOp : public AsyncOpKernel {
+ public:
+ CallOp(FunctionLibraryRuntime::Handle handle, OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx), handle_(handle) {}
+
+ ~CallOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."),
+ done);
+ FunctionLibraryRuntime::Options opts;
+ std::vector<Tensor> args;
+ args.reserve(ctx->num_inputs());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ }
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(opts, handle_, args, rets,
+ [ctx, done, rets](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ CHECK_EQ(rets->size(), ctx->num_outputs());
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ done();
+ });
+ }
+
+ private:
+ FunctionLibraryRuntime::Handle handle_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
+};
+
+const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
+ mutex_lock l(mu_);
+ CHECK_LE(0, h);
+ CHECK_LT(h, func_graphs_.size());
+ return func_graphs_[h];
+}
+
+Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
+ OpKernel** kernel) {
+ if (ndef.op() != kGradientOp && (lib_def_->Find(ndef.op()) == nullptr)) {
+ return CreateNonCachedKernel(device_, this, ndef, kernel);
+ }
+
+  // Try to instantiate this function for the func/attr. Maybe it's
+ // cached already.
+ Handle handle;
+ TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));
+
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+
+ // Constructs a CallOp kernel for running the instantiated function.
+ Status s;
+ auto device_type = DeviceType(device_->attributes().device_type());
+ OpKernelConstruction construction(
+ device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
+ &fbody->fdef.signature(), this, fbody->arg_types, fbody->ret_types, &s);
+ *kernel = new CallOp(handle, &construction);
+ if (!s.ok()) {
+    delete *kernel;
+ }
+ return s;
+}
+
+Status FunctionLibraryRuntimeImpl::FunctionDefToBody(
+ const FunctionDef& fdef, const InstantiateAttrValueMap& attrs,
+ FunctionBody** fbody) {
+ // Instantiates the function template into a graph def.
+ InstantiationResult result;
+ TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig_, &result));
+
+ Graph* graph = new Graph(lib_def_);
+ GraphConstructorOptions opts;
+ opts.allow_internal_ops = true;
+ opts.expect_device_spec = false;
+ Status s = ConvertGraphDefToGraph(opts, result.gdef, graph);
+ if (!s.ok()) {
+ delete graph;
+ } else {
+ *fbody = new FunctionBody(fdef, result.arg_types, result.ret_types, graph);
+ }
+ return s;
+}
+
+Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient(
+ const InstantiateAttrValueMap& attrs, FunctionBody** g_body) {
+ const AttrValue* f = gtl::FindOrNull(attrs, "f");
+ if (f == nullptr) {
+ return errors::InvalidArgument("SymbolicGradient is missing attr: f");
+ }
+ const auto& func = f->func();
+ const FunctionDef* fdef = lib_def_->Find(func.name());
+ if (fdef == nullptr) {
+    // f is a primitive op.
+ gradient::Creator creator;
+ TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
+ if (creator == nullptr) {
+ return errors::InvalidArgument("No gradient is defined for ",
+ func.name());
+ }
+ FunctionDef grad_fdef;
+ TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
+ TF_RETURN_IF_ERROR(FunctionDefToBody(grad_fdef, func.attr(), g_body));
+ } else {
+ // f is a user-defined function.
+ Handle f_handle;
+ TF_RETURN_IF_ERROR(Instantiate(func.name(), func.attr(), &f_handle));
+ const FunctionBody* f_body = GetFunctionBody(f_handle);
+ CHECK_NOTNULL(f_body);
+ *g_body = SymbolicGradient(*f_body);
+ }
+ return Status::OK();
+}
+
+Status FunctionLibraryRuntimeImpl::Instantiate(
+ const string& function_name, const InstantiateAttrValueMap& attrs,
+ Handle* handle) {
+ const string key = Canonicalize(function_name, attrs);
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ return Status::OK();
+ }
+ }
+
+ Status s;
+ FunctionBody* fbody = nullptr;
+ if (function_name == kGradientOp) {
+ TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(attrs, &fbody));
+ } else {
+ const FunctionDef* fdef = lib_def_->Find(function_name);
+ if (fdef == nullptr) {
+ return errors::NotFound("Function ", function_name, " is not defined.");
+ }
+ TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
+ }
+
+ {
+ mutex_lock l(mu_);
+ *handle = gtl::FindWithDefault(table_, key, kInvalidHandle);
+ if (*handle != kInvalidHandle) {
+ delete fbody;
+ } else {
+ *handle = func_graphs_.size();
+ table_.insert({key, *handle});
+ func_graphs_.push_back(fbody);
+ items_.resize(func_graphs_.size());
+ }
+ }
+ return Status::OK();
+}
+
+static void DumpGraph(const char* label, const Graph* g) {
+ if (VLOG_IS_ON(1)) {
+ LOG(INFO) << label << ": " << std::endl << DebugString(g);
+ }
+}
+
+static void SimplifyGraph(Graph* g) {
+ if (RemoveListArrayConverter(g)) {
+ DumpGraph("RemoveListArrayConverter", g);
+ }
+ bool changed;
+ do {
+ changed = false;
+ if (RemoveDeadNodes(g)) {
+ changed = true;
+ DumpGraph("RemoveDeadNodes", g);
+ }
+ if (RemoveIdentityNodes(g)) {
+ changed = true;
+ DumpGraph("RemoveIdentityNodes", g);
+ }
+ FixupSourceAndSinkEdges(g);
+ OptimizeCSE(g, nullptr);
+ DumpGraph("OptimizeCSE", g);
+ } while (changed);
+}
+
+void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
+ DumpGraph("Initial", *g);
+ const int kNumInlineRounds = 10;
+ for (int i = 0; i < kNumInlineRounds; ++i) {
+ if (!ExpandInlineFunctions(lib, *g)) break;
+ DumpGraph("ExpandInlineFunctions", *g);
+ SimplifyGraph(*g);
+ }
+
+ // Makes a copy so that we densify node ids.
+ Graph* copy = new Graph((*g)->op_registry());
+ CopyGraph(**g, copy);
+ delete *g;
+ *g = copy;
+ DumpGraph("ReCopy", *g);
+}
+
+Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+ Graph* g = new Graph(lib_def_);
+ CopyGraph(*fbody->graph, g);
+ OptimizeGraph(this, &g);
+
+  // Creates an executor based on the graph "g". This must be done without
+ // holding mu_ because create_kernel_ calls back into the library.
+ LocalExecutorParams params;
+ params.device = device_;
+ params.function_library = this;
+ params.has_control_flow = false;
+ params.create_kernel = create_kernel_;
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ Executor* exec;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &exec));
+
+ *item = new Item;
+ (*item)->exec = exec;
+ return Status::OK();
+}
+
+Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
+ {
+ mutex_lock l(mu_);
+ if (handle >= items_.size()) {
+ return errors::NotFound("Function handle ", handle,
+ " is not valid. Likely an internal error.");
+ }
+ *item = items_[handle];
+ if (*item != nullptr) {
+ (*item)->Ref();
+ return Status::OK();
+ }
+ }
+ // NOTE: We need to call CreateItem out of mu_ because creating an
+ // executor needs to call CreateKernel.
+ TF_RETURN_IF_ERROR(CreateItem(handle, item));
+
+ {
+ mutex_lock l(mu_);
+ if (items_[handle] == nullptr) {
+ // Install *item in items_.
+ items_[handle] = *item;
+ (*item)->Ref();
+ }
+ }
+ return Status::OK();
+}
+
+void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
+ gtl::ArraySlice<Tensor> args,
+ std::vector<Tensor>* rets,
+ DoneCallback done) {
+ if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
+ return done(errors::Cancelled(""));
+ }
+ const FunctionBody* fbody = GetFunctionBody(handle);
+ FunctionCallFrame* frame =
+ new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
+ Status s = frame->SetArgs(args);
+ if (!s.ok()) {
+ delete frame;
+ return done(s);
+ }
+ Item* item = nullptr;
+ s = GetOrCreateItem(handle, &item);
+ if (!s.ok()) {
+ delete frame;
+ return done(s);
+ }
+ Executor::Args exec_args;
+ exec_args.call_frame = frame;
+ exec_args.cancellation_manager = opts.cancellation_manager;
+ exec_args.runner = runner_;
+ item->exec->RunAsync(
+ // Executor args
+ exec_args,
+ // Done callback.
+ [item, frame, rets, done](const Status& status) {
+ item->Unref();
+ Status s = status;
+ if (s.ok()) {
+ s = frame->GetRetvals(rets);
+ }
+ delete frame;
+ done(s);
+ });
+}
+
+bool FunctionLibraryRuntimeImpl::IsDefined(const string& function_name) {
+ return lib_def_->Find(function_name) != nullptr;
+}
+
+FunctionLibraryRuntime* NewFunctionLibraryRuntime(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def) {
+ return new FunctionLibraryRuntimeImpl(device, runner, lib_def);
+}
+
+bool RemoveDeadNodes(Graph* g) {
+ std::vector<bool> visited(g->num_node_ids(), false);
+ visited[Graph::kSourceId] = true;
+ visited[Graph::kSinkId] = true;
+ std::deque<Node*> q;
+ for (auto n : g->nodes()) {
+ if (n->op_def().is_stateful()) {
+ visited[n->id()] = true;
+ } else if (n->type_string() == kArgOp) {
+ visited[n->id()] = true;
+ } else if (n->type_string() == kRetOp) {
+ visited[n->id()] = true;
+ q.push_back(n);
+ }
+ }
+ while (!q.empty()) {
+ const Node* n = q.front();
+ q.pop_front();
+ visited[n->id()] = true;
+ for (auto e : n->in_edges()) {
+ q.push_back(e->src());
+ }
+ }
+ bool removed_any = false;
+ for (Node* n : g->nodes()) {
+ if (!visited[n->id()]) {
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+namespace {
+// If 'edges' contains only 1 non-control edge, returns it. Otherwise,
+// returns a nullptr.
+const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
+ const Edge* ret = nullptr;
+ for (const Edge* e : edges) {
+ if (e->IsControlEdge() || ret) return nullptr;
+ ret = e;
+ }
+ return ret;
+}
+} // end namespace
+
+bool RemoveIdentityNodes(Graph* g) {
+ bool removed_any = false;
+ gtl::InlinedVector<Node*, 8> matches;
+ for (Node* n : g->nodes()) {
+ if ((n->type_string() == "Identity") && GetTheOnlyDataEdge(n->in_edges())) {
+ matches.push_back(n);
+ }
+ }
+ if (!matches.empty()) {
+ for (Node* n : matches) {
+ const Edge* in = GetTheOnlyDataEdge(n->in_edges());
+ for (const Edge* out : n->out_edges()) {
+ if (out->IsControlEdge()) {
+ g->AddControlEdge(in->src(), out->dst());
+ } else {
+ g->AddEdge(in->src(), in->src_output(), out->dst(), out->dst_input());
+ }
+ }
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+bool RemoveListArrayConverter(Graph* g) {
+ gtl::InlinedVector<Node*, 8> matches;
+ for (Node* n : g->nodes()) {
+ if ((n->type_string() == "_ListToArray") ||
+ (n->type_string() == "_ArrayToList")) {
+ matches.push_back(n);
+ }
+ }
+ bool removed_any = false;
+ if (!matches.empty()) {
+ for (Node* n : matches) {
+ if (n->num_inputs() != n->num_outputs()) {
+ continue; // Not expected. Skip.
+ }
+ gtl::InlinedVector<Node*, 8> identity_nodes(n->num_inputs(), nullptr);
+
+ // Process input edges first.
+ Node* input_control_node = nullptr;
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ if (input_control_node == nullptr) {
+            // If node "n" has any control dependencies, adds a no-op
+            // node (input_control_node) that the additional Identity
+            // nodes depend on; input_control_node in turn depends on
+            // node "n"'s control dependencies.
+ input_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(e->src(), input_control_node);
+ } else {
+ const int index = e->dst_input();
+ Node** id_node = &identity_nodes[index];
+ if (*id_node != nullptr) {
+ LOG(ERROR)
+ << "RemoveListArrayConverter unexpected duplicated input: "
+ << e->dst_input();
+ return removed_any;
+ }
+ *id_node = AddIdentity(g, {e->src(), e->src_output()});
+ }
+ }
+
+ // If node "n" has any control dependencies, the added identity
+ // nodes should have control dependencies on input_control_node.
+ if (input_control_node != nullptr) {
+ for (Node* id : identity_nodes) {
+ g->AddControlEdge(input_control_node, id);
+ }
+ }
+
+ Node* output_control_node = nullptr;
+ for (const Edge* e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ if (output_control_node == nullptr) {
+            // If node "n" is control-depended upon by other nodes,
+            // adds a no-op node (output_control_node) that those
+            // nodes will depend on; output_control_node in turn
+            // depends on all the Identity nodes.
+ output_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(output_control_node, e->dst());
+ } else {
+ Node* id_node = identity_nodes[e->src_output()];
+ if (id_node == nullptr) {
+ LOG(ERROR) << "RemoveListArrayConverter unexpected missing input: "
+ << e->src_output();
+ return removed_any;
+ }
+ CHECK(id_node);
+ g->AddEdge(id_node, 0, e->dst(), e->dst_input());
+ }
+ }
+
+ // If any nodes have control dependencies on node "n", those
+ // nodes should have control dependencies on
+ // output_control_node.
+ if (output_control_node != nullptr) {
+ for (Node* id : identity_nodes) {
+ g->AddControlEdge(id, output_control_node);
+ }
+ }
+
+ g->RemoveNode(n);
+ removed_any = true;
+ }
+ }
+ return removed_any;
+}
+
+// Returns true iff the function '*fbody' can be inlined at 'node'
+// based on the type signature of 'node' and 'fbody'.
+static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
+ if (static_cast<size_t>(node->num_inputs()) != fbody->arg_types.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_inputs()) != fbody->arg_nodes.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_outputs()) != fbody->ret_types.size()) {
+ return false;
+ }
+ if (static_cast<size_t>(node->num_outputs()) != fbody->ret_nodes.size()) {
+ return false;
+ }
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ if (node->input_type(i) != fbody->arg_types[i]) return false;
+ }
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ if (node->output_type(i) != fbody->ret_types[i]) return false;
+ }
+ return true;
+}
+
+// Given a "caller" in "graph", which is a function call of a function
+// to "fbody". Replaces the "caller" with fbody->graph and connects
+// edges properly.
+static void InlineFunctionBody(Graph* g, Node* caller,
+ const FunctionBody* fbody) {
+ if (!ValidateInlining(caller, fbody)) {
+ LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
+ << DebugString(fbody->graph);
+ return;
+ }
+
+ // Duplicate fbody->graph into 'g'. First, we copy the nodes of
+ // fbody->graph into 'g' except the source and sink nodes. We copy
+ // edges among nodes in 'fbody->graph'.
+ //
+ // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
+ // remember 'y' in node_map[x->id()].
+ std::vector<Node*> node_map(fbody->graph->num_node_ids());
+ for (Node* n : fbody->graph->nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n->id()] = g->CopyNode(n);
+ }
+ for (const Edge* e : fbody->graph->edges()) {
+ if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
+ e->dst()->IsSink()) {
+ continue;
+ }
+ Node* src_copy = node_map[e->src()->id()];
+ Node* dst_copy = node_map[e->dst()->id()];
+ g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+
+ // Connect input edges.
+ //
+ // For data edges coming into "caller", we first compute the
+ // <src>:<src_output> for the i-th input in "inputs". We create one
+  // Identity node for each input. Then, we connect inputs[i] to
+  // the i-th identity node added. The nodes that previously connected
+  // to an output of the i-th arg node are reconnected to the i-th
+  // identity node.
+ //
+ // If "caller" has any input control dependencies, we add a NoOp
+ // node "input_control_node". This "input_control_node" depends on
+ // what "caller" depends on, and the added identity nodes depend on
+ // "input_control_node".
+ std::vector<Endpoint> inputs(caller->num_inputs());
+ Node* input_control_node = nullptr;
+ for (const Edge* e : caller->in_edges()) {
+ if (e->IsControlEdge()) {
+ if (input_control_node == nullptr) {
+ input_control_node = AddNoOp(g);
+ }
+ g->AddControlEdge(e->src(), input_control_node);
+ } else {
+ inputs[e->dst_input()] = {e->src(), e->src_output()};
+ }
+ }
+ for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
+ Node* arg = node_map[fbody->arg_nodes[i]->id()];
+ Node* n = AddIdentity(g, inputs[i]);
+ if (input_control_node) {
+ g->AddControlEdge(input_control_node, n);
+ }
+ for (const Edge* e : arg->out_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(n, e->dst());
+ } else {
+ g->AddEdge(n, 0, e->dst(), e->dst_input());
+ }
+ }
+ node_map[fbody->arg_nodes[i]->id()] = n;
+ g->RemoveNode(arg); // 'arg' is disconnected.
+ }
+
+ // Connect output edges.
+ //
+ // For i-th return node in fbody->graph, we add in "g" an identity
+ // node (outputs[i-th]). We then reconnect every incoming edge into
+ // the i-th return node to the added identity node.
+ //
+  // For every data edge coming out of "caller"'s i-th output, we
+  // reconnect it to the i-th identity added above.
+  //
+  // If "caller" is control-depended upon by any other nodes, we add a
+  // NoOp node "output_control_node". "output_control_node" depends on
+  // all identity nodes added above, and nodes that previously depended
+  // on "caller" are changed to depend on "output_control_node".
+  std::vector<Node*> outputs(caller->num_outputs());
+ for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
+ Node* ret = node_map[fbody->ret_nodes[i]->id()];
+ Endpoint data; // Data input for the ret node.
+ for (const Edge* e : ret->in_edges()) {
+ if (!e->IsControlEdge()) {
+ data = {e->src(), e->src_output()};
+ break;
+ }
+ }
+ CHECK(data.node != nullptr);
+ Node* n = AddIdentity(g, data);
+ outputs[i] = n;
+ for (const Edge* e : ret->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), n);
+ }
+ }
+ g->RemoveNode(ret); // 'ret' is disconnected.
+ }
+ Node* output_control_node = nullptr;
+ for (const Edge* e : caller->out_edges()) {
+ if (e->IsControlEdge()) {
+ if (output_control_node == nullptr) {
+ output_control_node = AddNoOp(g);
+ for (Node* n : outputs) {
+ g->AddControlEdge(n, output_control_node);
+ }
+ }
+ g->AddControlEdge(output_control_node, e->dst());
+ } else {
+ g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
+ }
+ }
+ g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
+}
+
+bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
+ std::vector<std::pair<Node*, const FunctionBody*>> candidates;
+ for (Node* node : graph->nodes()) {
+ VLOG(3) << "Expanding " << node->DebugString();
+ FunctionLibraryRuntime::Handle handle;
+ Status s =
+ lib->Instantiate(node->type_string(), node->def().attr(), &handle);
+ if (!s.ok()) {
+ // Either "node" is a primitive op, or the instantiation failed.
+ if (errors::IsNotFound(s)) {
+ VLOG(2) << "ExpandInlineFunctions " << s;
+ } else {
+ LOG(ERROR) << "ExpandInlineFunctions " << s;
+ }
+ continue;
+ }
+ const FunctionBody* fbody = lib->GetFunctionBody(handle);
+ CHECK_NOTNULL(fbody);
+ candidates.push_back({node, fbody});
+ }
+ for (const auto& p : candidates) {
+ InlineFunctionBody(graph, p.first, p.second);
+ }
+ return !candidates.empty();
+}
+
+// TODO(zhifengc): Maybe this should be the default Graph::AsGraphDef
+// and stash the original NodeDef name as an attr for documentation
+// purposes.
+static void ToGraphDef(const Graph* g, GraphDef* gdef) {
+ // We visit nodes in forward topological sort order, which is a
+ // possible execution order of the graph.
+ std::vector<int> pending(g->num_node_ids());
+ std::deque<const Node*> ready;
+ for (const Node* n : g->nodes()) {
+ pending[n->id()] = n->in_edges().size();
+ if (pending[n->id()] == 0) ready.push_back(n);
+ }
+ gtl::InlinedVector<const Edge*, 4> inputs;
+ gdef->Clear();
+ while (!ready.empty()) {
+ const Node* n = ready.front();
+ ready.pop_front();
+ for (const Edge* e : n->out_edges()) {
+ const Node* next = e->dst();
+ if (--pending[next->id()] == 0) {
+ ready.push_back(next);
+ }
+ }
+ if (!n->IsOp()) continue;
+ NodeDef* ndef = gdef->add_node();
+ ndef->set_name(strings::StrCat("n", n->id()));
+ ndef->set_op(n->type_string());
+ *(ndef->mutable_attr()) = n->def().attr();
+ inputs.clear();
+ inputs.resize(n->num_inputs());
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ inputs.push_back(e);
+ } else {
+ if (inputs[e->dst_input()] == nullptr) {
+ inputs[e->dst_input()] = e;
+ } else {
+ LOG(WARNING) << "Malformed graph node. multiple input edges: "
+ << n->DebugString();
+ }
+ }
+ }
+    // node->name() is merely NodeDef::name, which is not guaranteed
+    // to be unique or stable after optimization rewrites. Therefore,
+ // we use "n<node id>" instead.
+ for (const Edge* e : inputs) {
+ if (e == nullptr) {
+ ndef->add_input("unknown");
+ } else if (!e->src()->IsOp()) {
+ } else if (e->IsControlEdge()) {
+ ndef->add_input(strings::StrCat("^n", e->src()->id()));
+ } else if (e->src_output() == 0) {
+ ndef->add_input(strings::StrCat("n", e->src()->id()));
+ } else {
+ ndef->add_input(
+ strings::StrCat("n", e->src()->id(), ":", e->src_output()));
+ }
+ }
+ }
+}
+
+string DebugString(const Graph* g) {
+ GraphDef gdef;
+ ToGraphDef(g, &gdef);
+ return DebugString(gdef);
+}
+
+FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
+ DataTypeSlice ret_t, Graph* g)
+ : fdef(f),
+ graph(g),
+ arg_types(arg_t.begin(), arg_t.end()),
+ ret_types(ret_t.begin(), ret_t.end()) {
+ this->arg_nodes.resize(arg_types.size());
+ this->ret_nodes.resize(ret_types.size());
+ for (Node* n : this->graph->nodes()) {
+ gtl::InlinedVector<Node*, 4>* node_vec;
+ if (n->type_string() == kRetOp) {
+ node_vec = &this->ret_nodes;
+ } else if (n->type_string() == kArgOp) {
+ node_vec = &this->arg_nodes;
+ } else {
+ continue;
+ }
+ int index;
+ TF_CHECK_OK(GetNodeAttr(n->def(), "index", &index));
+ CHECK_LE(0, index);
+ CHECK_LT(index, node_vec->size());
+ (*node_vec)[index] = n;
+ }
+}
+
+FunctionBody::~FunctionBody() { delete this->graph; }
+
+class SymbolicGradientHelper {
+ public:
+ explicit SymbolicGradientHelper(const FunctionBody& f) : fbody_(&f) {}
+
+ ~SymbolicGradientHelper() { delete gbody_; }
+
+ FunctionBody* Compute();
+
+ private:
+ const FunctionBody* fbody_;
+ FunctionBody* gbody_ = nullptr;
+
+ // A vector of output endpoints which represents backpropagated
+ // gradients
+ typedef std::vector<Endpoint> BackpropedGradients;
+
+ // backprops_ is a map from an output endpoint to its accumulated
+ // gradients. When an output endpoint has accumulated all its
+ // gradients, we add a node which sums them up.
+ std::unordered_map<Endpoint, BackpropedGradients, EndpointHash, EndpointEq>
+ backprops_;
+
+  // pending_[i] is a count-down counter for the i-th node's expected
+  // backprops.  When pending_[i] becomes zero, we have collected all
+  // backprop gradients for all output endpoints of the i-th node.
+ std::vector<int> pending_;
+
+  // 'ready_' keeps track of nodes that have been completely
+  // backpropped. Initially, for every output y of the function f, we
+  // add dy as an input of the gradient function.
+ std::deque<Node*> ready_;
+
+ // Makes a copy of fbody_ in gbody_.
+ void Copy();
+
+ // Initialize pending_ and ready_.
+ void InitBackprop();
+
+  // In the original function body, there is a forward edge from 'src'
+  // to 'dst'. When the backprop algorithm constructs the node
+  // 'dst_grad', which computes the gradient, we need to propagate it
+  // to 'src'.
+ void BackpropAlongEdge(const Endpoint& dst_grad, const Endpoint& src);
+ void BackpropZerosAlongEdge(const Endpoint& src);
+
+ Endpoint SumGradients(const Endpoint& src);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientHelper);
+};
+
+void SymbolicGradientHelper::Copy() {
+ const Graph& src = *(fbody_->graph);
+ gbody_->graph = new Graph(src.op_registry());
+ Graph* dst = gbody_->graph;
+
+ std::vector<Node*> node_map(src.num_node_ids());
+
+ // Copy the nodes.
+ node_map[src.source_node()->id()] = dst->source_node();
+ node_map[src.sink_node()->id()] = dst->sink_node();
+ for (Node* n : src.nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n->id()] = dst->CopyNode(n);
+ }
+
+ // Copy the edges.
+ for (const Edge* e : src.edges()) {
+ Node* src_copy = node_map[e->src()->id()];
+ Node* dst_copy = node_map[e->dst()->id()];
+ dst->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+
+ // Save inputs in copied graph.
+ CHECK_EQ(fbody_->arg_types.size(), fbody_->arg_nodes.size());
+ gbody_->arg_types = fbody_->arg_types;
+ for (std::size_t i = 0; i < fbody_->arg_nodes.size(); ++i) {
+ gbody_->arg_nodes.push_back(node_map[fbody_->arg_nodes[i]->id()]);
+ }
+
+ // Save outputs in copied graph.
+ CHECK_EQ(fbody_->ret_types.size(), fbody_->ret_nodes.size());
+ gbody_->ret_types = fbody_->ret_types;
+ for (std::size_t i = 0; i < fbody_->ret_nodes.size(); ++i) {
+ gbody_->ret_nodes.push_back(node_map[fbody_->ret_nodes[i]->id()]);
+ }
+}
+
+void SymbolicGradientHelper::BackpropAlongEdge(const Endpoint& dst_grad,
+ const Endpoint& src) {
+ CHECK_NOTNULL(src.node);
+ auto iter = backprops_.find(src);
+ if (iter != backprops_.end()) {
+ auto* grads = &iter->second;
+ grads->push_back(dst_grad);
+ if (--pending_[src.node->id()] == 0) {
+ ready_.push_back(src.node);
+ }
+ }
+}
+
+void SymbolicGradientHelper::BackpropZerosAlongEdge(const Endpoint& src) {
+ CHECK_NOTNULL(src.node);
+ auto iter = backprops_.find(src);
+ if (iter != backprops_.end()) {
+ if (--pending_[src.node->id()] == 0) {
+ ready_.push_back(src.node);
+ }
+ }
+}
+
+void SymbolicGradientHelper::InitBackprop() {
+ Graph* g = gbody_->graph;
+ pending_.resize(g->num_node_ids(), 0);
+ {
+ backprops_.clear();
+ std::unordered_set<Node*> visited;
+ std::deque<Node*> queue;
+ for (Node* n : gbody_->arg_nodes) {
+ queue.push_back(n);
+ }
+
+    // Go forward to figure out which endpoints need to be backprop-ed.
+    // A node's endpoints need to be backprop-ed only if one of the
+    // arg nodes can reach the node via data edges.
+ while (!queue.empty()) {
+ Node* n = queue.front();
+ queue.pop_front();
+ visited.insert(n);
+ for (int i = 0; i < n->num_outputs(); ++i) {
+ backprops_[{n, i}].clear();
+ }
+ int num_expected_backprops = 0;
+ for (const Edge* e : n->out_edges()) {
+ if (e->IsControlEdge()) continue;
+ ++num_expected_backprops;
+ if (visited.find(e->dst()) == visited.end()) {
+ queue.push_back(e->dst());
+ }
+ }
+ pending_[n->id()] = num_expected_backprops;
+ }
+ }
+
+ {
+ const int num_y = gbody_->ret_nodes.size();
+ for (int i = 0; i < num_y; ++i) {
+ Node* y = gbody_->ret_nodes[i];
+ DCHECK_EQ(y->type_string(), kRetOp);
+ const DataType dtype = y->input_type(0);
+ const int index = gbody_->arg_nodes.size();
+ Node* dy = AddArg(g, dtype, index);
+ gbody_->arg_types.push_back(dtype);
+ gbody_->arg_nodes.push_back(dy);
+
+ // What's the input to y?
+ Endpoint y_in{nullptr, 0};
+ for (const Edge* e : y->in_edges()) {
+ if (!e->IsControlEdge()) {
+ y_in = {e->src(), e->src_output()};
+ break;
+ }
+ }
+ CHECK_NOTNULL(y_in.node);
+ BackpropAlongEdge({dy, 0}, y_in);
+ }
+ }
+}
+
+Endpoint SymbolicGradientHelper::SumGradients(const Endpoint& src) {
+ Graph* g = gbody_->graph;
+ const DataType dtype = src.dtype();
+ auto iter = backprops_.find(src);
+ CHECK(iter != backprops_.end());
+ const auto& grads = iter->second;
+ if (grads.empty()) {
+    // Nothing propagated back. The best we can come up with is zeros.
+ Node* zero_like = AddZerosLike(g, src);
+ return {zero_like, 0};
+ }
+ if (grads.size() == 1) {
+ // Just one backprop edge.
+ return grads[0];
+ }
+ // Otherwise, adds backprop-ed gradients.
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("AddN"); // N-way Add
+ for (const Endpoint& ep : grads) {
+ ndef.add_input(ep.name());
+ }
+ AddNodeAttr("N", static_cast<int64>(grads.size()), &ndef);
+ AddNodeAttr("T", dtype, &ndef);
+ Status s;
+ Node* add = gbody_->graph->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ for (size_t i = 0; i < grads.size(); ++i) {
+ const Endpoint& ep = grads[i];
+ g->AddEdge(ep.node, ep.index, add, i);
+ }
+ return {add, 0};
+}
+
+static bool IsPrimitiveOpWithNoGrad(const string& func) {
+ gradient::Creator creator;
+ Status s = gradient::GetOpGradientCreator(func, &creator);
+ return s.ok() && (creator == nullptr);
+}
+
+FunctionBody* SymbolicGradientHelper::Compute() {
+ CHECK(gbody_ == nullptr);
+ gbody_ = new FunctionBody;
+
+ // Copy fbody_ into gbody_.
+ Copy();
+
+ // Initialize backprops.
+ InitBackprop();
+
+ // Backward propagation.
+ gtl::InlinedVector<Endpoint, 8> dy;
+ Graph* g = gbody_->graph;
+ while (!ready_.empty()) {
+ // n has collected all gradients.
+ Node* n = ready_.front();
+ ready_.pop_front();
+
+ if (n->type_string() == kArgOp) {
+ // We'll handle the _Arg node after backprop is done.
+ continue;
+ }
+
+ // "n" has num_x inputs and num_y outputs.
+ const int num_x = n->num_inputs();
+ const int num_y = n->num_outputs();
+
+ // dy[i] is the sum of i-th output's backpropped gradients.
+ dy.clear();
+ dy.resize(num_y, {nullptr, 0});
+ for (int i = 0; i < num_y; ++i) {
+ dy[i] = SumGradients({n, i});
+ }
+
+ if (IsPrimitiveOpWithNoGrad(n->type_string())) {
+ // No grad defined for this op. Backprops zeros along the in
+ // edges.
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BackpropZerosAlongEdge({e->src(), e->src_output()});
+ }
+ continue;
+ }
+
+ // Adds a gradient node with num_x + num_y inputs and num_x
+ // outputs.
+ Node* grad = AddSymGrad(g, n, dy);
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ g->AddEdge(e->src(), e->src_output(), grad, e->dst_input());
+ }
+ for (int i = 0; i < num_y; ++i) {
+ g->AddEdge(dy[i].node, dy[i].index, grad, num_x + i);
+ }
+
+ // Backprops along the in edges.
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) continue;
+ BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()});
+ }
+ }
+
+ // The gradient's retval nodes.
+ for (Node* n : gbody_->ret_nodes) {
+ g->RemoveNode(n);
+ }
+ gbody_->ret_types = fbody_->arg_types;
+ gbody_->ret_nodes.clear();
+ for (size_t i = 0; i < fbody_->arg_types.size(); ++i) {
+ Endpoint grad = SumGradients({gbody_->arg_nodes[i], 0});
+ Node* ret = AddRet(g, grad, i);
+ gbody_->ret_nodes.push_back(ret);
+ }
+
+ auto ret = gbody_;
+ gbody_ = nullptr;
+ return ret;
+}
+
+FunctionBody* SymbolicGradient(const FunctionBody& f) {
+ return SymbolicGradientHelper(f).Compute();
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
new file mode 100644
index 0000000000..634b31232a
--- /dev/null
+++ b/tensorflow/core/common_runtime/function.h
@@ -0,0 +1,100 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+#define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+
+#include <functional>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Creates a FunctionLibraryRuntime, which instantiates functions
+// defined in "lib_def" and executes functions on the "device".
+//
+// The returned object does not take ownership of "device" or
+// "lib_def". The caller must ensure "device" and "lib_def" outlive
+// the returned object.
+typedef std::function<void()> Closure;
+typedef std::function<void(Closure)> Runner;
+FunctionLibraryRuntime* NewFunctionLibraryRuntime(
+ Device* device, Runner runner, const FunctionLibraryDefinition* lib_def);
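+
+// A rough sketch of typical wiring, assuming a caller-owned thread pool
+// "pool" and a library function named "MyFunc" (placeholder names):
+//
+//   Runner runner = [pool](Closure c) { pool->Schedule(std::move(c)); };
+//   FunctionLibraryRuntime* flr =
+//       NewFunctionLibraryRuntime(device, runner, lib_def);
+//   FunctionLibraryRuntime::Handle handle;
+//   TF_CHECK_OK(flr->Instantiate("MyFunc", attrs, &handle));
+//   std::vector<Tensor> rets;
+//   flr->Run(FunctionLibraryRuntime::Options(), handle, {arg0}, &rets,
+//            [](const Status& s) { TF_CHECK_OK(s); });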
+
+// FunctionLibraryRuntime::GetFunctionBody returns a description of an
+// instantiated function that is represented as a Graph with arg/ret
+// nodes annotated.
+struct FunctionBody {
+ FunctionDef fdef;
+ Graph* graph = nullptr; // owned.
+ DataTypeVector arg_types;
+ DataTypeVector ret_types;
+ gtl::InlinedVector<Node*, 4> arg_nodes;
+ gtl::InlinedVector<Node*, 4> ret_nodes;
+
+ FunctionBody() {}
+ FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
+ DataTypeSlice ret_types, Graph* g);
+ ~FunctionBody();
+};
+
+// Debugging facility. Returns a debug string for a graph
+// representing an instantiated function.
+string DebugString(const Graph* instantiated_func_graph);
+
+// A few hand-crafted optimizations on the instantiated function body
+// (a Graph*).
+
+// Removes nodes that are
+// 1. not stateful; and
+// 2. not _Arg; and
+// 3. not reachable from _Retval.
+// Returns true iff any node is removed from "g".
+bool RemoveDeadNodes(Graph* g);
+
+// Find a pattern:
+// src -(in)-> node -(out)-> dst, where
+// 1) node is an identity node;
+// 2) in is the only incoming data edge;
+// 3) out is the only outgoing data edge;
+//
+// Rewrites the above pattern with src->dst and relevant data
+// dependencies updated. Repeats the process until no such pattern
+// is left.
+bool RemoveIdentityNodes(Graph* g);
+
+// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
+bool RemoveListArrayConverter(Graph* g);
+
+// For each node in "graph", if "lib" indicates that the node is a
+// function call, inline the function body. Returns true if at least
+// one node is inlined.
+//
+// This routine goes through "graph" nodes once and applies the
+// inlining. The caller may decide to apply the inlining on "graph"
+// multiple times by calling ExpandInlineFunctions a few times.
+bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph);
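+
+// A typical driver loop (a sketch; OptimizeGraph below interleaves
+// simplification passes between rounds in a similar way):
+//
+//   while (ExpandInlineFunctions(lib, graph)) {
+//     // Optionally simplify "graph" before the next round.
+//   }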
+
+// Applies graph rewrite optimizations such as inlining, dead code
+// removal, etc.
+//
+// **g is a graph constructed based on the runtime library 'lib'.
+// OptimizeGraph mutates **g extensively and replaces '*g' with a
+// complete copy. Therefore, the caller should not keep any references
+// to nodes in *g.
+void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g);
+
+// Given a numerical function "f", returns another numerical function
+// "g", such that if "f" takes N inputs and produces M outputs, "g"
+// takes N + M inputs and produces N outputs. I.e., if
+// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
+// g is a function which is
+// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
+// dL/dy1, dL/dy2, ..., dL/dy_M),
+// where L is a scalar-valued function of (...x_i...).
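+//
+// For example (an illustration, not taken from the library): if
+// f(x, y) = x * y has a single output z, then
+//   g(x, y, dz) = (dz * y, dz * x),
+// i.e. the gradients dL/dx and dL/dy given the upstream gradient dz = dL/dz.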
+//
+// TODO(zhifengc): Ask a math expert to restate the comment above.
+FunctionBody* SymbolicGradient(const FunctionBody& f);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
diff --git a/tensorflow/core/common_runtime/gpu/dma_helper.h b/tensorflow/core/common_runtime/gpu/dma_helper.h
new file mode 100644
index 0000000000..7b0750f405
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/dma_helper.h
@@ -0,0 +1,18 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_
+
+#include "tensorflow/core/public/tensor.h"
+
+// For internal use only. Visibility should be limited to brain/framework.
+
+namespace tensorflow {
+class DMAHelper {
+ public:
+ static bool CanUseDMA(const Tensor* t) { return t->CanUseDMA(); }
+ static const void* base(const Tensor* t) { return t->base<const void>(); }
+ static void* base(Tensor* t) { return t->base<void>(); }
+ static TensorBuffer* buffer(Tensor* t) { return t->buf_; }
+ static const TensorBuffer* buffer(const Tensor* t) { return t->buf_; }
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DMA_HELPER_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc
new file mode 100644
index 0000000000..742459c63b
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+GPUAllocatorRetry::GPUAllocatorRetry() : env_(Env::Default()) {}
+
+void* GPUAllocatorRetry::AllocateRaw(
+ std::function<void*(size_t alignment, size_t num_bytes,
+ bool verbose_failure)> alloc_func,
+ int max_millis_to_wait, size_t alignment, size_t num_bytes) {
+ if (num_bytes == 0) {
+ LOG(WARNING) << "Request to allocate 0 bytes";
+ return nullptr;
+ }
+ uint64 deadline_micros = env_->NowMicros() + max_millis_to_wait * 1000;
+ void* ptr = nullptr;
+ while (ptr == nullptr) {
+ ptr = alloc_func(alignment, num_bytes, false);
+ if (ptr == nullptr) {
+ uint64 now = env_->NowMicros();
+ if (now < deadline_micros) {
+ mutex_lock l(mu_);
+ WaitForMilliseconds(&l, &memory_returned_,
+ (deadline_micros - now) / 1000);
+ } else {
+ return alloc_func(alignment, num_bytes, true);
+ }
+ }
+ }
+ return ptr;
+}
+
+void GPUAllocatorRetry::DeallocateRaw(std::function<void(void*)> dealloc_func,
+ void* ptr) {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "Request to free nullptr";
+ return;
+ }
+ dealloc_func(ptr);
+ {
+ mutex_lock l(mu_);
+ memory_returned_.notify_all();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h
new file mode 100644
index 0000000000..a3298ab222
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+// A retrying wrapper for a memory allocator.
+class GPUAllocatorRetry {
+ public:
+ GPUAllocatorRetry();
+
+ // Call 'alloc_func' to obtain memory. On first call,
+ // 'verbose_failure' will be false. If return value is nullptr,
+ // then wait up to 'max_millis_to_wait' milliseconds, retrying each
+ // time a call to DeallocateRaw() is detected, until either a good
+ // pointer is returned or the deadline is exhausted. If the
+  // deadline is exhausted, try one more time with 'verbose_failure'
+ // set to true. The value returned is either the first good pointer
+ // obtained from 'alloc_func' or nullptr.
+ void* AllocateRaw(std::function<void*(size_t alignment, size_t num_bytes,
+ bool verbose_failure)> alloc_func,
+ int max_millis_to_wait, size_t alignment, size_t bytes);
+
+ // Calls dealloc_func(ptr) and then notifies any threads blocked in
+ // AllocateRaw() that would like to retry.
+ void DeallocateRaw(std::function<void(void* ptr)> dealloc_func, void* ptr);
+
+ private:
+ Env* env_;
+ mutex mu_;
+ condition_variable memory_returned_;
+};
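+
+// Sketch of intended use, wrapping some underlying allocator "base"
+// (a placeholder; any alloc/free pair with these shapes works):
+//
+//   GPUAllocatorRetry retry;
+//   void* ptr = retry.AllocateRaw(
+//       [base](size_t alignment, size_t num_bytes, bool verbose_failure) {
+//         return base->Alloc(alignment, num_bytes, verbose_failure);
+//       },
+//       10000 /* max_millis_to_wait */, alignment, num_bytes);
+//   ...
+//   retry.DeallocateRaw([base](void* p) { base->Free(p); }, ptr);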
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ALLOCATOR_RETRY_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
new file mode 100644
index 0000000000..db1c58cc65
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_allocator_retry_test.cc
@@ -0,0 +1,175 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class FakeAllocator {
+ public:
+ FakeAllocator(size_t cap, int millis_to_wait)
+ : memory_capacity_(cap), millis_to_wait_(millis_to_wait) {}
+
+ // Allocate just keeps track of the number of outstanding allocations,
+ // not their sizes. Assume a constant size for each.
+ void* AllocateRaw(size_t alignment, size_t num_bytes) {
+ return retry_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ mutex_lock l(mu_);
+ if (memory_capacity_ > 0) {
+ --memory_capacity_;
+ return good_ptr_;
+ } else {
+ return static_cast<void*>(nullptr);
+ }
+ },
+ millis_to_wait_, alignment, num_bytes);
+ }
+
+ void DeallocateRaw(void* ptr) {
+ retry_.DeallocateRaw(
+ [this](void* p) {
+ mutex_lock l(mu_);
+ ++memory_capacity_;
+ },
+ ptr);
+ }
+
+ private:
+ GPUAllocatorRetry retry_;
+ void* good_ptr_ = reinterpret_cast<void*>(0xdeadbeef);
+ mutex mu_;
+ size_t memory_capacity_ GUARDED_BY(mu_);
+ int millis_to_wait_;
+};
+
+class GPUAllocatorRetryTest : public ::testing::Test {
+ protected:
+ GPUAllocatorRetryTest() {}
+
+ void LaunchConsumerThreads(int num_consumers, int cap_needed) {
+ consumer_count_.resize(num_consumers, 0);
+ for (int i = 0; i < num_consumers; ++i) {
+ consumers_.push_back(Env::Default()->StartThread(
+ ThreadOptions(), "anon_thread", [this, i, cap_needed]() {
+ do {
+ void* ptr = nullptr;
+ for (int j = 0; j < cap_needed; ++j) {
+ ptr = alloc_->AllocateRaw(16, 1);
+ if (ptr == nullptr) {
+ mutex_lock l(mu_);
+ has_failed_ = true;
+ return;
+ }
+ }
+ ++consumer_count_[i];
+ for (int j = 0; j < cap_needed; ++j) {
+ alloc_->DeallocateRaw(ptr);
+ }
+ } while (!notifier_.HasBeenNotified());
+ }));
+ }
+ }
+
+ // Wait up to wait_micros microseconds for has_failed_ to equal expected,
+ // then terminate all threads.
+ void JoinConsumerThreads(bool expected, int wait_micros) {
+ while (wait_micros > 0) {
+ {
+ mutex_lock l(mu_);
+ if (has_failed_ == expected) break;
+ }
+ int interval_micros = std::min(1000, wait_micros);
+ Env::Default()->SleepForMicroseconds(interval_micros);
+ wait_micros -= interval_micros;
+ }
+ notifier_.Notify();
+ for (auto c : consumers_) {
+ // Blocks until thread terminates.
+ delete c;
+ }
+ }
+
+ std::unique_ptr<FakeAllocator> alloc_;
+ std::vector<Thread*> consumers_;
+ std::vector<int> consumer_count_;
+ Notification notifier_;
+ mutex mu_;
+ bool has_failed_ GUARDED_BY(mu_) = false;
+ int count_ GUARDED_BY(mu_) = 0;
+};
+
+// Verifies correct retrying when memory is slightly overcommitted but
+// we allow retry.
+TEST_F(GPUAllocatorRetryTest, RetrySuccess) {
+ // Support up to 2 allocations simultaneously, waits up to 10 seconds for
+ // a chance to alloc.
+ alloc_.reset(new FakeAllocator(2, 10000));
+ // Launch 3 consumers, each of whom needs 1 unit at a time.
+ LaunchConsumerThreads(3, 1);
+ // This should be enough time for each consumer to be satisfied many times.
+ Env::Default()->SleepForMicroseconds(50000);
+ JoinConsumerThreads(false, 0);
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i];
+ }
+ {
+ mutex_lock l(mu_);
+ EXPECT_FALSE(has_failed_);
+ }
+ EXPECT_GT(consumer_count_[0], 0);
+ EXPECT_GT(consumer_count_[1], 0);
+ EXPECT_GT(consumer_count_[2], 0);
+}
+
+// Verifies OutOfMemory failure when memory is slightly overcommitted
+// and retry is not allowed.
+TEST_F(GPUAllocatorRetryTest, NoRetryFail) {
+ // Support up to 2 allocations simultaneously, waits up to 0 msec for
+ // a chance to alloc.
+ alloc_.reset(new FakeAllocator(2, 0));
+ // Launch 3 consumers, each of whom needs 1 unit at a time.
+ LaunchConsumerThreads(3, 1);
+ Env::Default()->SleepForMicroseconds(50000);
+ // Will wait up to 10 seconds for proper race condition to occur, resulting
+ // in failure.
+ JoinConsumerThreads(true, 10000000);
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i];
+ }
+ {
+ mutex_lock l(mu_);
+ EXPECT_TRUE(has_failed_);
+ }
+}
+
+// Verifies OutOfMemory failure when retry is allowed but memory capacity
+// is too low even for retry.
+TEST_F(GPUAllocatorRetryTest, RetryInsufficientFail) {
+ // Support up to 2 allocations simultaneously, waits up to 10 seconds for
+ // a chance to alloc.
+ alloc_.reset(new FakeAllocator(2, 10000));
+ // Launch 3 consumers, each of whom needs 2 units at a time. We expect
+ // deadlock where 2 consumers each hold 1 unit, and timeout trying to
+ // get the second.
+ LaunchConsumerThreads(3, 2);
+ Env::Default()->SleepForMicroseconds(50000);
+ // Will wait up to 10 seconds for proper race condition to occur, resulting
+ // in failure.
+ JoinConsumerThreads(true, 10000000);
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Consumer " << i << " is " << consumer_count_[i];
+ }
+ {
+ mutex_lock l(mu_);
+ EXPECT_TRUE(has_failed_);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
new file mode 100644
index 0000000000..3df833594f
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -0,0 +1,397 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
+ : device_id_(device_id) {
+ // Get a pointer to the stream_executor for this device
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ // Allocate the requested amount of memory.
+ gpu_memory_size_ = total_memory;
+
+ LOG(INFO) << "Allocating " << strings::HumanReadableNumBytes(gpu_memory_size_)
+ << " bytes.";
+ gpu::DeviceMemory<char> gpu_mem =
+ stream_exec_->AllocateArray<char>(gpu_memory_size_);
+
+ QCHECK(gpu_mem != nullptr)
+ << " Could not allocate GPU device memory for device " << device_id
+ << ". Tried to allocate "
+ << strings::HumanReadableNumBytes(gpu_memory_size_);
+ base_ptr_ = gpu_mem.opaque();
+ LOG(INFO) << "GPU " << device_id << " memory begins at " << base_ptr_
+ << " extends to "
+ << static_cast<void*>(
+ (static_cast<char*>(base_ptr_) + gpu_memory_size_));
+
+ // Create a bunch of bins of various good sizes.
+
+ // Covers allocations of exactly 256 bytes (the minimum size).
+ bins_.insert(std::make_pair(256, new Bin(256)));
+
+ // We create bins covering every possible allocation size up to
+ // gpu_memory_size_: from allocations of at most 1024 bytes up to
+ // allocations as large as the memory limit itself.
+ for (size_t bin_size = 1024; bin_size < gpu_memory_size_ * 2; bin_size *= 2) {
+ LOG(INFO) << "Creating bin of max chunk size "
+ << strings::HumanReadableNumBytes(bin_size);
+ bins_.insert(std::make_pair(bin_size, new Bin(bin_size)));
+ }
+
+ // Create one large chunk for the whole memory space that will
+ // be chunked later.
+ GPUBFCAllocator::Chunk* c = new GPUBFCAllocator::Chunk();
+ c->ptr = gpu_mem.opaque();
+ c->size = gpu_memory_size_;
+ c->in_use = false;
+ c->prev = nullptr;
+ c->next = nullptr;
+
+ ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c));
+
+ // Insert the chunk into the right bin.
+ ReassignChunkToBin(c);
+}
+
+GPUBFCAllocator::~GPUBFCAllocator() {
+ // Return memory back.
+ if (base_ptr_) {
+ gpu::DeviceMemoryBase gpu_ptr{base_ptr_};
+ stream_exec_->Deallocate(&gpu_ptr);
+ }
+
+ gtl::STLDeleteValues(&bins_);
+}
+
+void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
+ static const int64 kMaxMillisToWait = 10000; // 10 seconds
+ return retry_helper_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ return AllocateRawInternal(a, nb, v);
+ },
+ kMaxMillisToWait, unused_alignment, num_bytes);
+}
+
+void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
+ size_t num_bytes,
+ bool dump_log_on_failure) {
+ if (num_bytes == 0) {
+ LOG(ERROR) << "tried to allocate 0 bytes";
+ return nullptr;
+ }
+ // First, always allocate memory of at least 256 bytes, and always
+ // allocate multiples of 256 bytes so all memory addresses are
+ // nicely byte aligned.
+ size_t rounded_bytes = (256 * ((num_bytes + 255) / 256));
+ DCHECK_EQ(0, rounded_bytes % 256);
+
+ // The BFC allocator tries to find the best fit first.
+ //
+ // First identify the first bin that could satisfy rounded_bytes.
+ auto it = bins_.lower_bound(rounded_bytes);
+ if (it == bins_.end()) {
+ LOG(ERROR) << " Asked for " << rounded_bytes << " but largest bin was "
+ << bins_.rbegin()->first;
+ return nullptr;
+ }
+
+ mutex_lock l(lock_);
+ for (; it != bins_.end(); ++it) {
+ // Start searching from the first bin for the smallest chunk that fits
+ // rounded_bytes.
+ Bin* b = it->second;
+ for (GPUBFCAllocator::Chunk* chunk : b->chunks) {
+ if (!chunk->in_use && chunk->size > rounded_bytes) {
+ // We found an existing chunk that fits us that wasn't in use.
+ chunk->in_use = true;
+
+ // If we can break the size of the chunk into two reasonably
+ // large pieces, do so.
+ //
+ // TODO(vrv): What should be the criteria when deciding when
+ // to split?
+ if (chunk->size >= rounded_bytes * 2) {
+ SplitChunk(chunk, rounded_bytes);
+ }
+
+ // The requested size of the returned chunk is what the user
+ // actually asked for (before rounding).
+ chunk->requested_size = num_bytes;
+
+ VLOG(4) << "Returning: " << chunk->ptr;
+ return chunk->ptr;
+ }
+ }
+ }
+
+ // We searched all bins for an existing free chunk to use and
+ // couldn't find one. This means we must have run out of memory.
+ // Dump the memory log for analysis.
+ if (dump_log_on_failure) {
+ DumpMemoryLog(rounded_bytes);
+ LOG(WARNING) << "Ran out of memory trying to allocate "
+ << strings::HumanReadableNumBytes(num_bytes)
+ << ". See logs for memory state";
+ }
+ return nullptr;
+}
+
+void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
+ // Create a new chunk starting num_bytes after c
+ GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk();
+ new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
+ VLOG(6) << "Adding to chunk map: " << new_chunk->ptr;
+ ptr_to_chunk_map_.insert(std::make_pair(new_chunk->ptr, new_chunk));
+
+ // Set the new sizes of the chunks.
+ new_chunk->size = c->size - num_bytes;
+ c->size = num_bytes;
+
+ // The new chunk is not in use.
+ new_chunk->in_use = false;
+
+ // Maintain the pointers.
+ // c <-> c_neighbor becomes
+ // c <-> new_chunk <-> c_neighbor
+ GPUBFCAllocator::Chunk* c_neighbor = c->next;
+ new_chunk->prev = c;
+ new_chunk->next = c_neighbor;
+ c->next = new_chunk;
+ if (c_neighbor) {
+ c_neighbor->prev = new_chunk;
+ }
+
+ // Maintain the bins
+ ReassignChunkToBin(new_chunk);
+ ReassignChunkToBin(c);
+}
+
+void GPUBFCAllocator::DeallocateRaw(void* ptr) {
+ retry_helper_.DeallocateRaw([this](void* p) { DeallocateRawInternal(p); },
+ ptr);
+}
+
+void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+ mutex_lock l(lock_);
+
+ // Find the chunk from the ptr.
+ auto it = ptr_to_chunk_map_.find(ptr);
+ CHECK(it != ptr_to_chunk_map_.end())
+ << "Asked to deallocate a pointer we never allocated: " << ptr;
+
+ GPUBFCAllocator::Chunk* c = it->second;
+ VLOG(6) << "Chunk at " << c->ptr << " no longer in use";
+ // Mark the chunk as no longer in use
+ c->in_use = false;
+
+ // Consider coalescing it.
+ MaybeCoalesce(c);
+}
+
+// Merges c1 and c2 when c1->next is c2 and c2->prev is c1.
+// We merge c2 into c1.
+void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
+ GPUBFCAllocator::Chunk* c2) {
+ // We can only merge chunks that are not in use.
+ DCHECK(!c1->in_use && !c2->in_use);
+
+ // c1's prev doesn't change, still points to the same ptr, and is
+ // still not in use.
+
+ // Fix up neighbor pointers
+ //
+ // c1 <-> c2 <-> c3 should become
+ // c1 <-> c3
+ GPUBFCAllocator::Chunk* c3 = c2->next;
+ c1->next = c3;
+ CHECK(c2->prev == c1);
+ if (c3 != nullptr) {
+ c3->prev = c1;
+ }
+
+ // Set the new size
+ c1->size += c2->size;
+
+ // Delete c2 and cleanup all state
+ RemoveChunkFromBin(c2);
+}
+
+void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) {
+ auto it = bins_.lower_bound(c->size);
+ CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size "
+ << c->size;
+
+ Bin* new_bin = it->second;
+
+ // If the bin has not changed, do nothing.
+ Bin* old_bin = c->bin;
+ if (old_bin != nullptr && new_bin == old_bin) {
+ return;
+ }
+
+ // The bin has changed. Add the chunk to the new bin and remove
+ // the chunk from the old bin.
+ new_bin->chunks.insert(c);
+ c->bin = new_bin;
+
+ if (old_bin == nullptr) {
+ return;
+ }
+
+ // Remove chunk from old bin
+ for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) {
+ if (*it == c) {
+ old_bin->chunks.erase(it);
+ return;
+ }
+ }
+ CHECK(false) << "Could not find chunk in old bin";
+}
+
+void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) {
+ Bin* b = c->bin;
+ for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) {
+ Chunk* other_c = *it;
+ if (other_c->ptr == c->ptr) {
+ b->chunks.erase(it);
+ VLOG(4) << "Removing: " << c->ptr;
+ ptr_to_chunk_map_.erase(c->ptr);
+ delete c;
+ return;
+ }
+ }
+
+ CHECK(false) << "Could not find chunk in bin";
+}
+
+void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+ // This chunk is no longer in-use, consider coalescing the chunk
+ // with adjacent chunks.
+ Chunk* chunk_to_reassign = nullptr;
+
+ // If the next chunk is free, coalesce the two, if the result would
+ // fit in an existing bin.
+ if (c->next && !c->next->in_use) {
+ VLOG(8) << "Chunk at " << c->next->ptr << " merging with c " << c->ptr;
+
+ chunk_to_reassign = c;
+
+ // Deletes c->next
+ Merge(c, c->next);
+ }
+
+ // If the previous chunk is free, coalesce the two
+ if (c->prev && !c->prev->in_use) {
+ VLOG(8) << "Chunk at " << c->ptr << " merging into c->prev "
+ << c->prev->ptr;
+
+ chunk_to_reassign = c->prev;
+
+ // Deletes c
+ Merge(c->prev, c);
+ }
+
+ // Reassign the final merged chunk into the right bin.
+ if (chunk_to_reassign) {
+ ReassignChunkToBin(chunk_to_reassign);
+ }
+}
+
+void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {
+ VLOG(1) << "AddVisitor";
+ mutex_lock l(lock_);
+ region_visitors_.push_back(visitor);
+ visitor(base_ptr_, gpu_memory_size_);
+}
+
+bool GPUBFCAllocator::TracksAllocationSizes() { return true; }
+
+size_t GPUBFCAllocator::RequestedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = ptr_to_chunk_map_.find(ptr);
+ CHECK(it != ptr_to_chunk_map_.end())
+ << "Asked for requested size of pointer we never allocated: " << ptr;
+ GPUBFCAllocator::Chunk* c = it->second;
+ return c->requested_size;
+}
+
+size_t GPUBFCAllocator::AllocatedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = ptr_to_chunk_map_.find(ptr);
+ CHECK(it != ptr_to_chunk_map_.end())
+ << "Asked for allocated size of pointer we never allocated: " << ptr;
+ GPUBFCAllocator::Chunk* c = it->second;
+ return c->size;
+}
+
+void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
+ // For each bin: tally up the total number of chunks and bytes.
+ for (auto bit : bins_) {
+ Bin* b = bit.second;
+
+ size_t total_bytes_in_use = 0;
+ size_t total_bytes_in_bin = 0;
+ size_t total_requested_bytes_in_use = 0;
+ size_t total_requested_bytes_in_bin = 0;
+ size_t total_chunks_in_use = 0;
+ size_t total_chunks_in_bin = 0;
+ for (Chunk* c : b->chunks) {
+ total_bytes_in_bin += c->size;
+ total_requested_bytes_in_bin += c->requested_size;
+ ++total_chunks_in_bin;
+ if (c->in_use) {
+ total_bytes_in_use += c->size;
+ total_requested_bytes_in_use += c->requested_size;
+ ++total_chunks_in_use;
+ }
+ }
+
+ LOG(INFO) << "Bin (" << b->bin_size
+ << "): \tTotal Chunks: " << total_chunks_in_bin
+ << ", Chunks in use: " << total_chunks_in_use << " "
+ << strings::HumanReadableNumBytes(total_bytes_in_bin)
+ << " allocated for chunks. "
+ << strings::HumanReadableNumBytes(total_requested_bytes_in_bin)
+ << " client-requested for chunks. "
+ << strings::HumanReadableNumBytes(total_bytes_in_use)
+ << " in use in bin. "
+ << strings::HumanReadableNumBytes(total_requested_bytes_in_use)
+ << " client-requested in use in bin.";
+ }
+
+ // Find the bin that we would have liked to allocate in, so we
+ // can get some further analysis about fragmentation.
+ auto it = bins_.lower_bound(num_bytes);
+ if (it != bins_.end()) {
+ Bin* b = it->second;
+
+ LOG(INFO) << "Bin for " << strings::HumanReadableNumBytes(num_bytes)
+ << " was " << strings::HumanReadableNumBytes(b->bin_size)
+ << ", Chunk State: ";
+
+ for (Chunk* c : b->chunks) {
+ LOG(INFO) << c->DebugString(true);
+ }
+ }
+}
+
+} // namespace tensorflow
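
The allocation path in AllocateRawInternal() above rounds every request up to a multiple of 256 bytes and then picks the first bin whose size is at least the rounded request via std::map::lower_bound. The small standalone program below walks through that arithmetic for a few request sizes; the bin labels are invented and only the first few power-of-two bins are modeled.

#include <cstddef>
#include <cstdio>
#include <map>

// Shows the request rounding and bin lookup used by AllocateRawInternal():
// round the request up to a multiple of 256 bytes, then take the first
// bin whose size is >= the rounded request. Bin sizes are illustrative.
int main() {
  std::map<std::size_t, const char*> bins = {{256, "bin-256"},
                                             {1024, "bin-1024"},
                                             {2048, "bin-2048"},
                                             {4096, "bin-4096"}};
  for (std::size_t num_bytes : {1u, 300u, 1025u, 5000u}) {
    std::size_t rounded = 256 * ((num_bytes + 255) / 256);
    auto it = bins.lower_bound(rounded);
    std::printf("request %zu -> rounded %zu -> %s\n", num_bytes, rounded,
                it == bins.end() ? "(needs a larger bin)" : it->second);
  }
  return 0;
}

For instance, a 300-byte request rounds to 512 bytes and lands in the 1024-byte bin, while a 5000-byte request rounds to 5120 bytes and needs a bin beyond the ones modeled here.
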
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
new file mode 100644
index 0000000000..3d1601e132
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -0,0 +1,156 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm. This is essentially a very simple version of Doug Lea's
+// malloc (dlmalloc).
+//
+// The goal of this allocator is to support defragmentation via
+// coalescing. One assumption we make is that the process using this
+// allocator owns pretty much all of the GPU memory, and that nearly
+// all requests to allocate GPU memory go through this interface.
+class GPUBFCAllocator : public VisitableAllocator {
+ public:
+ // 'device_id' refers to the StreamExecutor ID of the device within
+ // the process and must reference a valid ID in the process.
+ explicit GPUBFCAllocator(int device_id, size_t total_memory);
+ ~GPUBFCAllocator() override;
+
+ string Name() override { return "gpu_bfc"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+
+ void AddAllocVisitor(Visitor visitor) override;
+
+ // Does nothing, because gpu memory is never freed.
+ void AddFreeVisitor(Visitor visitor) override {}
+
+ bool TracksAllocationSizes() override;
+
+ size_t RequestedSize(void* ptr) override;
+
+ size_t AllocatedSize(void* ptr) override;
+
+ private:
+ struct Bin;
+
+ void* AllocateRawInternal(size_t alignment, size_t num_bytes,
+ bool dump_log_on_failure);
+ void DeallocateRawInternal(void* ptr);
+
+ // Chunks point to GPU memory. Their prev/next pointers form a
+ // doubly-linked list of addresses sorted by GPU base address that
+ // must be contiguous. Chunks contain information about whether
+ // they are in use or whether they are free, and contain a pointer
+ // to the bin they are in.
+ struct Chunk {
+ size_t size = 0; // Full size of GPU buffer.
+
+ // We sometimes give chunks that are larger than needed to reduce
+ // fragmentation. requested_size keeps track of what the client
+ // actually wanted so we can understand whether our splitting
+ // strategy is efficient.
+ size_t requested_size = 0;
+
+ bool in_use = false;
+ void* ptr = nullptr; // pointer to granted GPU subbuffer.
+
+ // If not null, the memory referred to by 'prev' is directly
+ // preceding the memory used by this chunk; e.g., it should start
+ // at 'ptr - prev->size'.
+ Chunk* prev = nullptr;
+
+ // If not null, the memory referred to by 'next' is directly
+ // following the memory used by this chunk; e.g., it should be at
+ // 'ptr + size'.
+ Chunk* next = nullptr;
+
+ // What bin are we in?
+ Bin* bin = nullptr;
+
+ string DebugString(bool recurse) {
+ string dbg;
+ strings::StrAppend(&dbg, " Size: ", strings::HumanReadableNumBytes(size),
+ " | Requested Size: ",
+ strings::HumanReadableNumBytes(requested_size),
+ " | in_use: ", in_use);
+ if (recurse && prev) {
+ strings::StrAppend(&dbg, ", prev: ", prev->DebugString(false));
+ }
+ if (recurse && next) {
+ strings::StrAppend(&dbg, ", next: ", next->DebugString(false));
+ }
+ return dbg;
+ }
+ };
+
+ Chunk* AllocateNewChunk(size_t num_bytes);
+ void SplitChunk(Chunk* c, size_t num_bytes);
+ void Merge(Chunk* c1, Chunk* c2);
+ void MaybeCoalesce(Chunk* c);
+
+ void ReassignChunkToBin(Chunk* c);
+ void RemoveChunkFromBin(Chunk* c);
+
+ void DumpMemoryLog(size_t num_bytes);
+
+ // A Bin is a collection of similar-sized Chunks.
+ struct Bin {
+ // All chunks in this bin have size <= bin_size.
+ size_t bin_size = 0;
+
+ struct ChunkComparator {
+ bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; }
+ };
+
+ // List of chunks within the bin, sorted by chunk size.
+ std::multiset<Chunk*, ChunkComparator> chunks;
+
+ explicit Bin(size_t bs) : bin_size(bs) {}
+
+ ~Bin() { gtl::STLDeleteElements(&chunks); }
+ };
+
+ GPUAllocatorRetry retry_helper_;
+
+ // Structures immutable after construction
+ const int device_id_;
+ // The base pointer where all the GPU memory begins.
+ void* base_ptr_ = nullptr;
+ size_t gpu_memory_size_ = 0;
+
+ // Map from bin size to Bin
+ // After construction, the bin map is never resized.
+ std::map<size_t, Bin*> bins_;
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ // Structures mutable after construction
+ mutable mutex lock_;
+ // Not owned.
+ std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;
+
+ // Called once on each region, ASAP.
+ std::vector<Visitor> region_visitors_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
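
The Chunk structure above describes an address-ordered doubly-linked list in which SplitChunk() carves a tail off a chunk and Merge()/MaybeCoalesce() fold free neighbors back together. The sketch below models just that list bookkeeping on host memory; bins, the pointer-to-chunk map, and the GPU region are omitted, and MiniChunk is an invented name.

#include <cassert>
#include <cstddef>

// Minimal model of the chunk list: Split() carves the tail off a chunk,
// Merge() folds a free right-hand neighbor back in.
struct MiniChunk {
  char* ptr = nullptr;
  std::size_t size = 0;
  bool in_use = false;
  MiniChunk* prev = nullptr;
  MiniChunk* next = nullptr;
};

MiniChunk* Split(MiniChunk* c, std::size_t bytes) {
  MiniChunk* right = new MiniChunk;
  right->ptr = c->ptr + bytes;
  right->size = c->size - bytes;
  right->prev = c;
  right->next = c->next;
  if (c->next != nullptr) c->next->prev = right;
  c->next = right;
  c->size = bytes;
  return right;
}

void Merge(MiniChunk* c) {
  MiniChunk* right = c->next;
  assert(right != nullptr && !c->in_use && !right->in_use);
  c->size += right->size;
  c->next = right->next;
  if (right->next != nullptr) right->next->prev = c;
  delete right;
}

int main() {
  char region[1024];
  MiniChunk* whole = new MiniChunk;
  whole->ptr = region;
  whole->size = sizeof(region);

  MiniChunk* rest = Split(whole, 256);  // 1024 -> 256 + 768
  assert(whole->size == 256 && rest->size == 768);

  Merge(whole);  // both halves free: coalesce back into one chunk
  assert(whole->size == 1024 && whole->next == nullptr);
  delete whole;
  return 0;
}
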
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
new file mode 100644
index 0000000000..7b5e8aec1d
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -0,0 +1,166 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+namespace {
+
+TEST(GPUBFCAllocatorTest, NoDups) {
+ GPUBFCAllocator a(0, 1 << 30);
+ // Allocate a lot of raw pointers
+ std::vector<void*> ptrs;
+ for (int s = 1; s < 1024; s++) {
+ void* raw = a.AllocateRaw(1, s);
+ ptrs.push_back(raw);
+ }
+
+ std::sort(ptrs.begin(), ptrs.end());
+
+ // Make sure none of them are equal, and that none of them overlap.
+ for (int i = 0; i < ptrs.size(); i++) {
+ if (i > 0) {
+ ASSERT_NE(ptrs[i], ptrs[i - 1]); // No dups
+ size_t req_size = a.RequestedSize(ptrs[i - 1]);
+ ASSERT_GT(req_size, 0);
+ ASSERT_GE(static_cast<char*>(ptrs[i]) - static_cast<char*>(ptrs[i - 1]),
+ req_size);
+ }
+ }
+
+ for (int i = 0; i < ptrs.size(); i++) {
+ a.DeallocateRaw(ptrs[i]);
+ }
+}
+
+TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
+ GPUBFCAllocator a(0, 1 << 30);
+ // Allocate 255 raw pointers of sizes between 100 bytes and about
+ // 1MiB
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rand(&philox);
+
+ std::vector<void*> initial_ptrs;
+ for (int s = 1; s < 256; s++) {
+ size_t size = std::min<size_t>(
+ std::max<size_t>(rand.Rand32() % 1048576, 100), 1048576);
+ void* raw = a.AllocateRaw(1, size);
+
+ initial_ptrs.push_back(raw);
+ }
+
+ // Deallocate half of the memory, and keep track of the others.
+ std::vector<void*> existing_ptrs;
+ for (int i = 0; i < initial_ptrs.size(); i++) {
+ if (i % 2 == 1) {
+ a.DeallocateRaw(initial_ptrs[i]);
+ } else {
+ existing_ptrs.push_back(initial_ptrs[i]);
+ }
+ }
+
+ // Allocate a lot of raw pointers
+ for (int s = 1; s < 256; s++) {
+ size_t size = std::min<size_t>(
+ std::max<size_t>(rand.Rand32() % 1048576, 100), 1048576);
+ void* raw = a.AllocateRaw(1, size);
+ existing_ptrs.push_back(raw);
+ }
+
+ std::sort(existing_ptrs.begin(), existing_ptrs.end());
+ // Make sure none of them are equal
+ for (int i = 0; i < existing_ptrs.size(); i++) {
+ if (i > 0) {
+ CHECK_NE(existing_ptrs[i], existing_ptrs[i - 1]); // No dups
+
+ size_t req_size = a.RequestedSize(existing_ptrs[i - 1]);
+ ASSERT_GT(req_size, 0);
+
+ // Check that they don't overlap.
+ ASSERT_GE(static_cast<char*>(existing_ptrs[i]) -
+ static_cast<char*>(existing_ptrs[i - 1]),
+ req_size);
+ }
+ }
+
+ for (int i = 0; i < existing_ptrs.size(); i++) {
+ a.DeallocateRaw(existing_ptrs[i]);
+ }
+}
+
+TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
+ GPUBFCAllocator a(0, 1 << 30);
+
+ float* first_ptr = a.Allocate<float>(1024);
+ a.Deallocate(first_ptr);
+ for (int i = 0; i < 1024; ++i) {
+ // Allocate several buffers of different sizes, and then clean them
+ // all up. We should be able to repeat this endlessly without
+ // causing fragmentation and growth.
+ float* t1 = a.Allocate<float>(1024);
+
+ int64* t2 = a.Allocate<int64>(1048576);
+ double* t3 = a.Allocate<double>(2048);
+ float* t4 = a.Allocate<float>(10485760);
+
+ a.Deallocate(t1);
+ a.Deallocate(t2);
+ a.Deallocate(t3);
+ a.Deallocate(t4);
+ }
+
+ // At the end, we should have coalesced all memory into one region
+ // starting at the beginning, so validate that allocating a pointer
+ // starts from this region.
+ float* first_ptr_after = a.Allocate<float>(1024);
+ EXPECT_EQ(first_ptr, first_ptr_after);
+ a.Deallocate(first_ptr_after);
+}
+
+TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
+ GPUBFCAllocator a(0, 1 << 30);
+ float* ptr = a.Allocate<float>(0);
+ EXPECT_EQ(nullptr, ptr);
+}
+
+TEST(GPUBFCAllocatorTest, TracksSizes) {
+ GPUBFCAllocator a(0, 1 << 30);
+ EXPECT_EQ(true, a.TracksAllocationSizes());
+}
+
+TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
+ GPUBFCAllocator a(0, 1 << 30);
+ float* t1 = a.Allocate<float>(1);
+ EXPECT_EQ(4, a.RequestedSize(t1));
+ EXPECT_EQ(256, a.AllocatedSize(t1));
+ a.Deallocate(t1);
+}
+
+TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+ // Configure a 1MiB limit
+ GPUBFCAllocator a(0, 1 << 20);
+
+ float* first_ptr = a.Allocate<float>(1 << 6);
+ float* second_ptr = a.Allocate<float>(1 << 20);
+
+ EXPECT_NE(nullptr, first_ptr);
+ EXPECT_EQ(nullptr, second_ptr);
+ a.Deallocate(first_ptr);
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
new file mode 100644
index 0000000000..5ec405cd80
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -0,0 +1,186 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+#define MASK_WORDS 2
+#define MASK_BYTES (MASK_WORDS * sizeof(int64))
+
+namespace {
+
+static int64* NewMask(int64 word) {
+ int64* m = new int64[MASK_WORDS];
+ for (int i = 0; i < MASK_WORDS; ++i) {
+ m[i] = word;
+ }
+ return m;
+}
+
+static int64* before_mask = NewMask(0xabababababababab);
+static int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd);
+
+bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr,
+ int64* mask) {
+ gpu::DeviceMemory<int64> gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}};
+ int64 tmp[MASK_WORDS];
+
+ if (!exec->SynchronousMemcpy(&tmp, gpu_ptr, MASK_BYTES)) {
+ LOG(FATAL) << "Could not copy debug mask";
+ }
+
+ bool ok = true;
+ for (int i = 0; i < MASK_WORDS; ++i) {
+ ok &= (mask[i] == tmp[i]);
+ if (!ok) {
+ LOG(ERROR) << "i=" << i
+ << " mask=" << reinterpret_cast<const void*>(mask[i])
+ << " field=" << reinterpret_cast<const void*>(tmp[i]);
+ }
+ }
+
+ return ok;
+}
+
+void InitMask(perftools::gputools::StreamExecutor* exec, void* ptr,
+ int64* mask) {
+ gpu::DeviceMemory<int64> gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}};
+ if (!exec->SynchronousMemcpy(&gpu_ptr, mask, MASK_BYTES)) {
+ LOG(FATAL) << "Could not copy debug mask";
+ }
+}
+
+} // namespace
+
+// -----------------------------------------------------------------------------
+// GPUDebugAllocator
+// -----------------------------------------------------------------------------
+GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
+ int device_id)
+ : base_allocator_(allocator) {
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+}
+
+GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
+
+void* GPUDebugAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ num_bytes += (2 * MASK_BYTES);
+
+ void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes);
+
+ // Return the pointer after the header
+ void* rv = static_cast<char*>(allocated_ptr) + MASK_BYTES;
+
+ // Write the header at allocated_ptr
+ InitMask(stream_exec_, allocated_ptr, before_mask);
+
+ // Write the footer at the end.
+ size_t req_size = base_allocator_->RequestedSize(allocated_ptr);
+ InitMask(stream_exec_,
+ static_cast<char*>(allocated_ptr) + req_size - MASK_BYTES,
+ after_mask);
+ return rv;
+}
+void GPUDebugAllocator::DeallocateRaw(void* ptr) {
+ CHECK(CheckHeader(ptr)) << "before_mask has been overwritten";
+ CHECK(CheckFooter(ptr)) << "after_mask has been overwritten";
+
+ // Backtrack to the beginning of the header.
+ ptr = static_cast<void*>(static_cast<char*>(ptr) - MASK_BYTES);
+ // Deallocate the memory
+ base_allocator_->DeallocateRaw(ptr);
+}
+
+void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
+ return base_allocator_->AddAllocVisitor(visitor);
+}
+
+void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
+ return base_allocator_->AddFreeVisitor(visitor);
+}
+
+bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
+
+size_t GPUDebugAllocator::RequestedSize(void* ptr) {
+ auto req_size =
+ base_allocator_->RequestedSize(static_cast<char*>(ptr) - MASK_BYTES);
+ return req_size - 2 * MASK_BYTES;
+}
+
+size_t GPUDebugAllocator::AllocatedSize(void* ptr) {
+ return base_allocator_->AllocatedSize(static_cast<char*>(ptr) - MASK_BYTES);
+}
+
+bool GPUDebugAllocator::CheckHeader(void* ptr) {
+ return CheckMask(stream_exec_, static_cast<char*>(ptr) - MASK_BYTES,
+ before_mask);
+}
+
+bool GPUDebugAllocator::CheckFooter(void* ptr) {
+ char* original_ptr = static_cast<char*>(ptr) - MASK_BYTES;
+ size_t req_size = base_allocator_->RequestedSize(original_ptr);
+ return CheckMask(stream_exec_, original_ptr + req_size - MASK_BYTES,
+ after_mask);
+}
+
+// -----------------------------------------------------------------------------
+// GPUNanResetAllocator
+// -----------------------------------------------------------------------------
+GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
+ int device_id)
+ : base_allocator_(allocator) {
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+}
+
+GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
+
+void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes);
+
+ // Initialize the buffer to Nans
+ size_t req_size = base_allocator_->RequestedSize(allocated_ptr);
+ std::vector<float> nans(req_size / sizeof(float), std::nanf(""));
+ gpu::DeviceMemory<float> nan_ptr{
+ gpu::DeviceMemoryBase{static_cast<float*>(allocated_ptr), req_size}};
+
+ if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
+ LOG(ERROR) << "Could not initialize to NaNs";
+ }
+
+ return allocated_ptr;
+}
+void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
+ // Reset the buffer to Nans
+ size_t req_size = base_allocator_->RequestedSize(ptr);
+ std::vector<float> nans(req_size / sizeof(float), std::nanf(""));
+ gpu::DeviceMemory<float> nan_ptr{
+ gpu::DeviceMemoryBase{static_cast<float*>(ptr), req_size}};
+ if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
+ LOG(ERROR) << "Could not initialize to NaNs";
+ }
+
+ // Deallocate the memory
+ base_allocator_->DeallocateRaw(ptr);
+}
+
+void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
+ return base_allocator_->AddAllocVisitor(visitor);
+}
+
+void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
+ return base_allocator_->AddFreeVisitor(visitor);
+}
+
+size_t GPUNanResetAllocator::RequestedSize(void* ptr) {
+ return base_allocator_->RequestedSize(ptr);
+}
+
+size_t GPUNanResetAllocator::AllocatedSize(void* ptr) {
+ return base_allocator_->AllocatedSize(ptr);
+}
+
+} // namespace tensorflow
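
GPUDebugAllocator surrounds each client buffer with a known header mask and footer mask and re-checks both on deallocation. The host-memory sketch below reproduces the same layout arithmetic with malloc/memcpy/memcmp standing in for the StreamExecutor copies; it is an illustration of the scheme, not the allocator itself.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Guard-band layout: [ header mask | user bytes | footer mask ]. The mask
// values mirror before_mask/after_mask above.
constexpr int kMaskWords = 2;
constexpr std::size_t kMaskBytes = kMaskWords * sizeof(std::uint64_t);
const std::uint64_t kBefore[kMaskWords] = {0xababababababababULL,
                                           0xababababababababULL};
const std::uint64_t kAfter[kMaskWords] = {0xcdcdcdcdcdcdcdcdULL,
                                          0xcdcdcdcdcdcdcdcdULL};

void* GuardedAlloc(std::size_t user_bytes) {
  char* base = static_cast<char*>(std::malloc(user_bytes + 2 * kMaskBytes));
  std::memcpy(base, kBefore, kMaskBytes);                           // header
  std::memcpy(base + kMaskBytes + user_bytes, kAfter, kMaskBytes);  // footer
  return base + kMaskBytes;  // hand back the pointer just past the header
}

bool GuardedFree(void* user_ptr, std::size_t user_bytes) {
  char* base = static_cast<char*>(user_ptr) - kMaskBytes;
  bool ok =
      std::memcmp(base, kBefore, kMaskBytes) == 0 &&
      std::memcmp(base + kMaskBytes + user_bytes, kAfter, kMaskBytes) == 0;
  std::free(base);
  return ok;
}

int main() {
  const std::size_t n = 64;
  char* p = static_cast<char*>(GuardedAlloc(n));
  std::memset(p, 0, n);  // stays inside the user region
  std::printf("clean free ok: %d\n", GuardedFree(p, n));

  p = static_cast<char*>(GuardedAlloc(n));
  p[n] = 0x7f;  // one byte past the end clobbers the footer
  std::printf("overrun detected: %d\n", !GuardedFree(p, n));
  return 0;
}
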
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
new file mode 100644
index 0000000000..c9b564ffc4
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -0,0 +1,68 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// An allocator that wraps a GPU allocator and adds debugging
+// functionality that verifies that users do not write outside their
+// allocated memory.
+class GPUDebugAllocator : public VisitableAllocator {
+ public:
+ explicit GPUDebugAllocator(VisitableAllocator* allocator, int device_id);
+ ~GPUDebugAllocator() override;
+ string Name() override { return "gpu_debug"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void AddAllocVisitor(Visitor visitor) override;
+ void AddFreeVisitor(Visitor visitor) override;
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ // For testing.
+ bool CheckHeader(void* ptr);
+ bool CheckFooter(void* ptr);
+
+ private:
+ VisitableAllocator* base_allocator_ = nullptr; // owned
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUDebugAllocator);
+};
+
+// An allocator that wraps a GPU allocator and resets the memory on
+// allocation and free to 'NaN', helping to identify cases where the
+// user forgets to initialize the memory.
+class GPUNanResetAllocator : public VisitableAllocator {
+ public:
+ explicit GPUNanResetAllocator(VisitableAllocator* allocator, int device_id);
+ ~GPUNanResetAllocator() override;
+ string Name() override { return "gpu_nan_reset"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void AddAllocVisitor(Visitor visitor) override;
+ void AddFreeVisitor(Visitor visitor) override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ private:
+ VisitableAllocator* base_allocator_ = nullptr; // owned
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUNanResetAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
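
GPUNanResetAllocator's behavior -- fill the buffer with NaNs on allocation (and again on free) so that uninitialized reads are conspicuous -- can be shown on host memory. The sketch below uses plain new[]/delete[] in place of the wrapped GPU allocator and is purely illustrative.

#include <cmath>
#include <cstddef>
#include <cstdio>

// Fill a freshly "allocated" buffer with quiet NaNs so that any element
// the client never wrote stands out as non-finite.
float* NanAllocate(std::size_t n) {
  float* buf = new float[n];
  for (std::size_t i = 0; i < n; ++i) buf[i] = std::nanf("");
  return buf;
}

int main() {
  const std::size_t n = 8;
  float* buf = NanAllocate(n);
  buf[0] = 1.0f;  // the only element the "client" initializes
  int untouched = 0;
  for (std::size_t i = 0; i < n; ++i) {
    if (!std::isfinite(buf[i])) ++untouched;
  }
  std::printf("%d of %zu elements were never written\n", untouched, n);
  delete[] buf;
  return 0;
}
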
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
new file mode 100644
index 0000000000..5f63906576
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -0,0 +1,207 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
+ const int device_id = 0;
+ GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30), device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ for (int s : {8}) {
+ std::vector<int64> cpu_array(s);
+ memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
+ int64* gpu_array = a.Allocate<int64>(cpu_array.size());
+ gpu::DeviceMemory<int64> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
+ s * sizeof(int64)));
+ EXPECT_TRUE(a.CheckHeader(gpu_array));
+ EXPECT_TRUE(a.CheckFooter(gpu_array));
+
+ // Confirm no error on free.
+ a.DeallocateRaw(gpu_array);
+ }
+}
+
+TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
+ for (int s : {8, 211}) {
+ EXPECT_DEATH(
+ {
+ const int device_id = 0;
+ GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30),
+ device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<int64> cpu_array(s);
+ memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
+ int64* gpu_array = a.Allocate<int64>(cpu_array.size());
+
+ gpu::DeviceMemory<int64> gpu_array_ptr{
+ gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(
+ &gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64)));
+
+ gpu::DeviceMemory<int64> gpu_hdr_ptr{
+ gpu::DeviceMemoryBase{gpu_array - 1}};
+ // Clobber first word of the header.
+ float pi = 3.1417;
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&gpu_hdr_ptr, &pi, sizeof(float)));
+
+ // Expect error on free.
+ a.Deallocate(gpu_array);
+ },
+ "");
+ }
+}
+
+TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
+ for (int s : {8, 22}) {
+ EXPECT_DEATH(
+ {
+ const int device_id = 0;
+ GPUDebugAllocator a(new GPUBFCAllocator(device_id, 1 << 30),
+ device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<int64> cpu_array(s);
+ memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
+ int64* gpu_array = a.Allocate<int64>(cpu_array.size());
+
+ gpu::DeviceMemory<int64> gpu_array_ptr{
+ gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(
+ &gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64)));
+
+ // Clobber word of the footer.
+ gpu::DeviceMemory<int64> gpu_ftr_ptr{
+ gpu::DeviceMemoryBase{gpu_array + s}};
+ float pi = 3.1417;
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&gpu_ftr_ptr, &pi, sizeof(float)));
+
+ // Expect error on free.
+ a.Deallocate(gpu_array);
+ },
+ "");
+ }
+}
+
+TEST(GPUDebugAllocatorTest, ResetToNan) {
+ const int device_id = 0;
+ GPUNanResetAllocator a(new GPUBFCAllocator(device_id, 1 << 30), device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<float> cpu_array(1024);
+ std::vector<float> cpu_array_result(1024);
+
+ // Allocate 1024 floats
+ float* gpu_array = a.Allocate<float>(cpu_array.size());
+ gpu::DeviceMemory<float> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr,
+ cpu_array.size() * sizeof(float)));
+ for (float f : cpu_array) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+
+ // Set one of the fields to 1.0.
+ cpu_array[0] = 1.0;
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
+ cpu_array.size() * sizeof(float)));
+ // Copy the data back and verify.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ ASSERT_EQ(1.0, cpu_array_result[0]);
+
+ // Free the array
+ a.Deallocate(gpu_array);
+
+ // All values should be reset to nan.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ for (float f : cpu_array_result) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+}
+
+TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
+ const int device_id = 0;
+ // NaN reset must be the outer-most allocator.
+ GPUNanResetAllocator a(
+ new GPUDebugAllocator(new GPUBFCAllocator(device_id, 1 << 30), device_id),
+ device_id);
+ auto stream_exec =
+ GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ std::vector<float> cpu_array(1024);
+ std::vector<float> cpu_array_result(1024);
+
+ // Allocate 1024 floats
+ float* gpu_array = a.Allocate<float>(cpu_array.size());
+ gpu::DeviceMemory<float> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr,
+ cpu_array.size() * sizeof(float)));
+ for (float f : cpu_array) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+
+ // Set one of the fields to 1.0.
+ cpu_array[0] = 1.0;
+ ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
+ cpu_array.size() * sizeof(float)));
+ // Copy the data back and verify.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ ASSERT_EQ(1.0, cpu_array_result[0]);
+
+ // Free the array
+ a.Deallocate(gpu_array);
+
+ // All values should be reset to nan.
+ ASSERT_TRUE(
+ stream_exec->SynchronousMemcpy(&cpu_array_result[0], gpu_array_ptr,
+ cpu_array_result.size() * sizeof(float)));
+ for (float f : cpu_array_result) {
+ ASSERT_FALSE(std::isfinite(f));
+ }
+}
+
+TEST(GPUDebugAllocatorTest, TracksSizes) {
+ GPUDebugAllocator a(new GPUBFCAllocator(0, 1 << 30), 0);
+ EXPECT_EQ(true, a.TracksAllocationSizes());
+}
+
+TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
+ GPUNanResetAllocator a(
+ new GPUDebugAllocator(new GPUBFCAllocator(0, 1 << 30), 0), 0);
+ float* t1 = a.Allocate<float>(1);
+ EXPECT_EQ(4, a.RequestedSize(t1));
+ EXPECT_EQ(256, a.AllocatedSize(t1));
+ a.Deallocate(t1);
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
new file mode 100644
index 0000000000..26d34645f1
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -0,0 +1,651 @@
+// TODO(opensource): Use a more generic sounding preprocessor name than
+// GOOGLE_CUDA
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+//#include "base/commandlineflags.h"
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_bool(brain_gpu_sync_every_op, false,
+ "If true, call GPUUtil::Sync() between every dispatched opkernel.");
+
+DEFINE_int32(brain_gpu_max_streams, 1,
+ "Max number of GPU streams to use for computation.");
+#else
+// TODO(opensource): These should be made options in some options struct,
+// rather than flags.
+bool FLAGS_brain_gpu_sync_every_op = false;
+tensorflow::int32 FLAGS_brain_gpu_max_streams = 1;
+#endif
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+// Eigen Ops directly allocate memory only for temporary buffers used
+// during OpKernel::Compute(). The recommended way of allocating such
+// memory is via OpKernelContext::allocate_temp(). However, Eigen Ops
+// don't have access to OpKernelContext, instead they get access to
+// memory directly through the device allocator. As an Open Source
+// project, Eigen assumes allocator semantics similar to those of the
+// CUDA memory allocator, and may not work correctly due to race
+// conditions if used with some other allocator. For safety, we need
+// to delay deallocation calls out of Eigen until all events on the
+// corresponding stream have completed. The following two classes
+// serve this purpose in two different compilation environments.
+
+#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
+class EigenAllocator : public ::Eigen::Allocator {
+ public:
+ explicit EigenAllocator(gpu::Stream* stream, ::tensorflow::Allocator* alloc,
+ EventMgr* em)
+ : stream_(stream), allocator_(alloc), em_(em) {}
+
+ void* allocate(size_t num_bytes) const override {
+ void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
+ // Eigen doesn't typically check the return pointer from allocate,
+ // so we do it here and die with a more helpful error message.
+ if (ret == nullptr) {
+ LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
+ << num_bytes << ". See error logs for more detailed info.";
+ }
+ return ret;
+ }
+
+ void deallocate(void* buffer) const override {
+ em_->ThenDeleteBuffer(stream_, {allocator_, buffer});
+ }
+
+ private:
+ gpu::Stream* stream_; // Not owned.
+ ::tensorflow::Allocator* allocator_; // Not owned.
+ ::tensorflow::EventMgr* em_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EigenAllocator);
+};
+
+#else
+class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
+ public:
+ EigenCudaStreamDevice(const cudaStream_t* cuda_stream, int gpu_id,
+ ::tensorflow::Allocator* alloc)
+ : stream_(cuda_stream), allocator_(alloc) {
+ Eigen::initializeDeviceProp();
+ device_prop_ = &Eigen::m_deviceProperties[gpu_id];
+ }
+
+ const cudaStream_t& stream() const override { return *stream_; }
+ const cudaDeviceProp& deviceProperties() const override {
+ return *device_prop_;
+ }
+
+ void* allocate(size_t num_bytes) const override {
+ void* ret = allocator_->AllocateRaw(32 /* alignment */, num_bytes);
+ if (ret == nullptr) {
+ LOG(FATAL) << "EigenAllocator for GPU ran out of memory when allocating "
+ << num_bytes << ". See error logs for more detailed info.";
+ }
+
+ return ret;
+ }
+ void deallocate(void* buffer) const override {
+ AsyncFreeData* afData = new AsyncFreeData(allocator_, buffer);
+ cudaError_t err = cudaStreamAddCallback(*stream_, asyncFree, afData, 0);
+ CHECK_EQ(err, cudaSuccess);
+ }
+
+ private:
+ struct AsyncFreeData {
+ AsyncFreeData(::tensorflow::Allocator* a, void* p)
+ : allocator_(a), address_(p) {}
+ ::tensorflow::Allocator* allocator_;
+ void* address_;
+ };
+
+ static void CUDART_CB asyncFree(cudaStream_t stream, cudaError_t status,
+ void* userData) {
+ AsyncFreeData* data = static_cast<AsyncFreeData*>(userData);
+ data->allocator_->DeallocateRaw(data->address_);
+ delete data;
+ }
+
+ const cudaStream_t* stream_; // Not owned.
+ const cudaDeviceProp* device_prop_; // Not owned.
+ ::tensorflow::Allocator* allocator_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EigenCudaStreamDevice);
+};
+
+#endif
+
+BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency,
+ int gpu_id, const string& physical_device_desc,
+ Allocator* gpu_allocator, Allocator* cpu_allocator)
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ name, DEVICE_GPU, memory_limit, bus_adjacency,
+ physical_device_desc),
+ gpu_allocator),
+ gpu_allocator_(gpu_allocator),
+ cpu_allocator_(cpu_allocator),
+ gpu_id_(gpu_id) {
+ gpu::StreamExecutor* executor =
+ GPUMachineManager()->ExecutorForDevice(gpu_id_).ValueOrDie();
+ if (!executor) {
+ LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_;
+ return;
+ }
+ em_.reset(new EventMgr(executor));
+
+ if (FLAGS_brain_gpu_max_streams < 1) {
+ LOG(FATAL) << "Invalid value for brain_gpu_max_streams.";
+ }
+
+ // Create the specified number of GPU streams
+ for (int i = 0; i < FLAGS_brain_gpu_max_streams; i++) {
+ auto stream = new gpu::Stream(executor);
+ stream->Init();
+ VLOG(2) << "Created stream[" << i << "] = " << stream;
+ streams_.push_back(stream);
+ device_contexts_.push_back(new GPUDeviceContext(i, stream));
+ }
+ gpu_device_info_ = new GpuDeviceInfo;
+ gpu_device_info_->stream = streams_[0];
+ gpu_device_info_->default_context = device_contexts_[0];
+ gpu_device_info_->event_mgr = em_.get();
+ set_tensorflow_gpu_device_info(gpu_device_info_);
+}
+
+BaseGPUDevice::~BaseGPUDevice() {
+ delete gpu_device_info_;
+ for (auto ctx : device_contexts_) ctx->Unref();
+ gtl::STLDeleteElements(&streams_);
+}
+
+Status BaseGPUDevice::FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) {
+ VLOG(2) << "FillContextMap";
+
+ const auto num_streams = streams_.size();
+ // Special case for single stream.
+ if (num_streams == 1) {
+ return Status::OK();
+ }
+ const int64 before = Env::Default()->NowMicros();
+ gpu_stream_util::AssignStreamsOpts opts;
+ opts.max_streams = num_streams;
+ std::unordered_map<int, int> node_to_stream_id;
+ TF_RETURN_IF_ERROR(
+ gpu_stream_util::AssignStreams(graph, opts, &node_to_stream_id));
+ int64 elapsed = Env::Default()->NowMicros() - before;
+ VLOG(3) << "AssignStreams took " << elapsed << "us";
+
+ // Fill in the context map. It is OK for this map to contain
+ // duplicate DeviceContexts so long as we increment the refcount.
+ for (Node* n : graph->nodes()) {
+ auto mapped_stream = node_to_stream_id[n->id()];
+ CHECK_LT(mapped_stream, num_streams);
+ auto ctx = device_contexts_[mapped_stream];
+ VLOG(3) << "Assigned stream " << node_to_stream_id[n->id()]
+ << " ==> stream[" << ctx->stream_id() << "] for node id " << n->id()
+ << " " << n->type_string() << " " << n->name();
+ ctx->Ref();
+ device_context_map->insert(std::make_pair(n->id(), ctx));
+ }
+
+ return Status::OK();
+}
+
+void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ // ScopedActivity is cheap when tracing is not active, but we
+ // can still avoid computing the Hash64 in that case.
+ // TODO(pbar) This would no longer be needed if Ops have a unique id.
+ const uint64 id = port::Tracing::IsActive() ? Hash64(op_kernel->name()) : 0;
+ port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
+ id);
+
+ GPUDeviceContext* gpu_device_context = device_contexts_[0];
+ if (context->op_device_context() != nullptr) {
+ gpu_device_context =
+ static_cast<GPUDeviceContext*>(context->op_device_context());
+ }
+ gpu::Stream* stream = gpu_device_context->stream();
+ const auto stream_id = gpu_device_context->stream_id();
+
+ VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op "
+ << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
+ << stream_id << "]";
+
+ // NOTE(tucker): We need to discriminate between Eigen GPU
+ // operations and all others. If an operation is Eigen
+ // implemented (or otherwise tries to launch a cuda kernel
+ // directly), we need to establish a stacked-scoped environment
+ // that directs it to execute on the proper device. Otherwise we
+ // expect the Op to use StreamExecutor directly and correctly. The
+ // way we make this discrimination is quite hacky: At the moment
+ // the only non-Eigen GPU Op is the recv-op, which is known to be
+ // asynchronous.
+ if (op_kernel->type_string() == "_Recv") {
+ context->SetStatus(errors::Internal(
+ "Invalid synchronous 'Compute' on GPU for '_Recv' op"));
+ } else {
+ const string label =
+ strings::StrCat(op_kernel->name(), ":", op_kernel->type_string());
+ port::Tracing::ScopedAnnotation annotation(label);
+
+ const auto num_streams = streams_.size();
+ if (num_streams > 1) {
+ // If this op's device context is different from the other contexts,
+ // we must wait on the stream.
+ for (int i = 0; i < context->num_inputs(); ++i) {
+ const GPUDeviceContext* idc =
+ static_cast<GPUDeviceContext*>(context->input_device_context(i));
+ OP_REQUIRES(context, idc != nullptr,
+ errors::Internal("Input device context ", i,
+ " was not set properly."));
+ if (VLOG_IS_ON(2)) {
+ const void* base;
+ size_t len;
+ if (context->has_input(i)) {
+ if (IsRefType(context->input_dtype(i))) {
+ Tensor tensor = context->mutable_input(i, false);
+ base = DMAHelper::base(&tensor);
+ len = tensor.TotalBytes();
+ } else {
+ const Tensor& tensor = context->input(i);
+ base = DMAHelper::base(&tensor);
+ len = tensor.TotalBytes();
+ }
+ VLOG(2) << "Input " << i << " " << base << " " << len;
+ VLOG(2) << " stream[" << stream_id << "].ThenWaitFor(stream["
+ << idc->stream_id() << "])"
+ << ((idc->stream() == stream) ? " not needed" : "");
+ }
+ }
+ if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
+ }
+ }
+ gpu::cuda::ScopedActivateExecutorContext scoped_activation{
+ stream->parent(), gpu::cuda::MultiOpActivation::kYes};
+ // Keep a copy of the inputs before Compute runs, in case they get
+ // deleted. TODO(misard) this will be fixed when the tracking is
+ // done right.
+ std::vector<Tensor>* tensor_refs = nullptr;
+ if (!FLAGS_brain_gpu_sync_every_op) {
+ tensor_refs = new std::vector<Tensor>;
+ tensor_refs->reserve(context->num_inputs() + context->num_outputs());
+ for (int ii = 0; ii < context->num_inputs(); ++ii) {
+ if (context->has_input(ii)) {
+ if (IsRefType(context->input_dtype(ii))) {
+ Tensor in = context->mutable_input(ii, false);
+ tensor_refs->push_back(in);
+ } else {
+ const Tensor& in = context->input(ii);
+ tensor_refs->push_back(in);
+ }
+ }
+ }
+ }
+ op_kernel->Compute(context);
+ if (context->status().ok()) {
+ if (FLAGS_brain_gpu_sync_every_op) {
+ // Note: GPUUtil::Sync() only syncs the default stream.
+ // We need to either sync the stream used by this op, or
+ // all streams. Given that this flag is typically used for
+ // debugging it makes more sense to sync all GPU activity.
+ context->SetStatus(GPUUtil::SyncAll(this));
+ } else {
+ // The GPU kernel has been queued, but may not complete for some
+ // time. As soon as this function completes, the caller will
+ // discard its refs on the inputs, outputs and any scratch
+ // tensors it created. Create additional refs here that will be
+ // held until the kernel completes.
+ for (int ii = 0; ii < context->num_temps(); ++ii) {
+ Tensor* temp = context->temp(ii);
+ VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp);
+ tensor_refs->push_back(*temp);
+ }
+ for (int ii = 0; ii < context->num_outputs(); ++ii) {
+ Tensor* temp = context->mutable_output(ii);
+ if (nullptr != temp) {
+ tensor_refs->push_back(*temp);
+ }
+ }
+ em_->ThenDeleteTensors(stream, tensor_refs);
+ }
+ } else {
+ if (!FLAGS_brain_gpu_sync_every_op) {
+ delete tensor_refs;
+ }
+ }
+ }
+}
+
+Status BaseGPUDevice::Sync() { return GPUUtil::Sync(this); }
+
+void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
+ OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) {
+ GPUDeviceContext* gpu_device_context = device_contexts_[0];
+ if (context->op_device_context() != nullptr) {
+ gpu_device_context =
+ static_cast<GPUDeviceContext*>(context->op_device_context());
+ }
+ const auto stream_id = gpu_device_context->stream_id();
+
+ VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op "
+ << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
+ << stream_id << "]";
+
+ port::Tracing::TraceMe activity(
+ strings::StrCat(op_kernel->name(), ":", op_kernel->type_string()));
+ op_kernel->ComputeAsync(context, done);
+}
+
+Status BaseGPUDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ attr.set_gpu_compatible(true);
+ Allocator* host_alloc = GetAllocator(attr);
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(host_alloc, tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+ Status status;
+ if (alloc_attrs.on_host()) {
+ *tensor = parsed;
+ } else {
+ if (!DMAHelper::CanUseDMA(&parsed)) {
+ return errors::Internal("GPU copy from non-DMA ",
+ DataTypeString(parsed.dtype()), " tensor");
+ }
+ Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
+ port::Tracing::ScopedAnnotation annotation("MakeTensorFromProto");
+ Notification n;
+ device_contexts_[0]->CopyCPUTensorToDevice(&parsed, this, &copy,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ *tensor = copy;
+ }
+ return status;
+}
+
+namespace {
+#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
+class ConcretePerOpGpuDevice : public PerOpGpuDevice {
+ public:
+ explicit ConcretePerOpGpuDevice(gpu::Stream* stream,
+ EigenAllocator* allocator)
+ : device_(stream, allocator), allocator_(allocator) {}
+ ~ConcretePerOpGpuDevice() { delete allocator_; }
+
+ const Eigen::GpuDevice& device() const override { return device_; }
+
+ private:
+ Eigen::GpuDevice device_;
+ EigenAllocator* allocator_;
+};
+#else
+class ConcretePerOpGpuDevice : public PerOpGpuDevice {
+ public:
+ explicit ConcretePerOpGpuDevice(EigenCudaStreamDevice* stream_device)
+ : device_(stream_device), stream_device_(stream_device) {}
+ ~ConcretePerOpGpuDevice() { delete stream_device_; }
+
+ const Eigen::GpuDevice& device() const override { return device_; }
+
+ private:
+ Eigen::GpuDevice device_;
+ EigenCudaStreamDevice* stream_device_;
+};
+#endif
+} // namespace
+
+const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id,
+ Allocator* allocator) {
+#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
+ auto ea = new EigenAllocator(streams_[stream_id], allocator, em_.get());
+ return new ConcretePerOpGpuDevice(streams_[stream_id], ea);
+#else
+ const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
+ streams_[stream_id]->implementation()->CudaStreamMemberHack());
+ auto es = new EigenCudaStreamDevice(cuda_stream, gpu_id_, allocator);
+ return new ConcretePerOpGpuDevice(es);
+#endif
+}
+
+const PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice(DeviceContext* dc,
+ Allocator* allocator) {
+ if (dc) {
+ const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
+ const int stream_id = gpu_dc->stream_id();
+ VLOG(1) << " eigen_gpu_device(" << dc << ") => stream[" << stream_id
+ << "]";
+ CHECK_LT(stream_id, streams_.size());
+ return NewDevice(stream_id, allocator);
+ } else {
+ return NewDevice(0, allocator);
+ }
+}
+
+void BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ int n = INT_MAX;
+ auto iter = options.config.device_count().find("GPU");
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
+ std::vector<int> valid_gpu_ids;
+ GetValidDeviceIds(&valid_gpu_ids);
+ if (static_cast<size_t>(n) > valid_gpu_ids.size()) {
+ n = valid_gpu_ids.size();
+ }
+ for (int i = 0; i < n; i++) {
+ devices->push_back(CreateGPUDevice(
+ options, strings::StrCat(name_prefix, "/gpu:", i), valid_gpu_ids[i]));
+ }
+}
+
+namespace {
+int64 MinSystemMemory(int64 available_memory) {
+ // We use the following heuristic for now:
+ //
+ // If the available_memory is < 2GiB, we allocate 200MiB to system memory.
+  // Otherwise, allocate the larger of 300MiB or 5% of available memory to system memory.
+ //
+ // In the future we could be more sophisticated by using a table of
+ // devices.
+ if (available_memory < (1LL << 31)) {
+ // 200MiB
+ return 209715200LL;
+ } else {
+    // max(300 MiB, 0.05 * available_memory)
+ return std::max(314572800LL, static_cast<int64>(available_memory * 0.05));
+ }
+}
+} // namespace
+
+static string GetShortDeviceDescription(int device_id,
+ const gpu::DeviceDescription& desc) {
+ return strings::StrCat("device: ", device_id, ", name: ", desc.name(),
+ ", pci bus id: ", desc.pci_bus_id());
+}
+
+LocalDevice* BaseGPUDeviceFactory::CreateGPUDevice(
+ const SessionOptions& options, const string& name, int gpu_id) {
+ CHECK_GE(gpu_id, 0);
+
+ // Look up the device, to see its attributes.
+ gpu::Platform* gpu_platform = GPUMachineManager();
+ CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
+ gpu::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+
+ int64 total_memory, available_memory;
+ CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
+
+ int64 allocated_memory = available_memory;
+ double config_memory_fraction =
+ options.config.gpu_options().per_process_gpu_memory_fraction();
+ if (config_memory_fraction == 0) {
+ const int64 min_system_memory = MinSystemMemory(available_memory);
+ if (min_system_memory < allocated_memory) {
+ allocated_memory -= min_system_memory;
+ }
+ } else {
+ allocated_memory *= config_memory_fraction;
+ }
+
+ Bytes allocated_bytes = static_cast<Bytes>(allocated_memory);
+
+ // Get GPU BusAdjacency from its reported NUMA affinity.
+ // Because GPUs are virtualized in some environments, we can't just
+ // use the GPU id.
+ BusAdjacency bus_adjacency = BUS_ANY;
+ switch (desc.numa_node()) {
+ case 0:
+ bus_adjacency = BUS_0;
+ break;
+ case 1:
+ bus_adjacency = BUS_1;
+ break;
+ default:
+ bus_adjacency = BUS_ANY;
+ }
+ VLOG(1) << "GPUDevice id " << gpu_id << " on bus " << bus_adjacency
+ << " numa: " << desc.numa_node() << " pci: " << desc.pci_bus_id();
+
+ ProcessState* process_state = ProcessState::singleton();
+ return CreateGPUDevice(
+ options, name, allocated_bytes, bus_adjacency, gpu_id,
+ GetShortDeviceDescription(gpu_id, desc),
+ process_state->GetGPUAllocator(gpu_id, allocated_memory),
+ process_state->GetCPUAllocator(desc.numa_node()));
+}
+
+static int GetMinGPUMultiprocessorCount() {
+ static const int kDefaultMinGPUMultiprocessorCount = 8;
+
+ const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
+
+ if (tf_min_gpu_core_count == nullptr ||
+ strcmp(tf_min_gpu_core_count, "") == 0) {
+ return kDefaultMinGPUMultiprocessorCount;
+ }
+
+ int min_gpu_core_count = -1;
+ if (strings::safe_strto32(tf_min_gpu_core_count, &min_gpu_core_count)) {
+ if (min_gpu_core_count >= 0) {
+ return min_gpu_core_count;
+ }
+ }
+
+ LOG(ERROR) << "Invalid minimum GPU multiprocessor count: ["
+ << tf_min_gpu_core_count << "]. "
+ << "Using the default value: "
+ << kDefaultMinGPUMultiprocessorCount;
+ return kDefaultMinGPUMultiprocessorCount;
+}
+
+void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
+ auto gpu_manager = GPUMachineManager();
+ int min_gpu_core_count = GetMinGPUMultiprocessorCount();
+ if (gpu_manager) {
+ auto visible_device_count = gpu_manager->VisibleDeviceCount();
+ for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
+ auto exec_status = gpu_manager->ExecutorForDevice(i);
+ if (!exec_status.ok()) {
+ continue;
+ }
+ gpu::StreamExecutor* se = exec_status.ValueOrDie();
+ const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+ int major, minor;
+ if (!desc.cuda_compute_capability(&major, &minor)) {
+ continue;
+ }
+ // Only consider GPUs with compute capability >= 3.5 (Kepler or
+ // higher)
+ if (major < 3 || (major == 3 && minor < 5)) {
+ LOG(INFO) << "Ignoring gpu device "
+ << "(" << GetShortDeviceDescription(i, desc) << ") "
+ << "with Cuda compute capability " << major << "." << minor
+ << ". The minimum required Cuda capability is 3.5.";
+ continue;
+ }
+
+ // TensorFlow currently places computation on devices assuming
+ // they have similar capability.
+ //
+ // If there are multiple GPUs available on the machine, only
+ // consider GPUs with 8 or more multiprocessors.
+ //
+ // TODO(vrv): In the medium term: we should only filter out GPUs
+ // that are slow relative to the fastest GPU. In the long term,
+ // TensorFlow should support automatic placement based on
+ // capability.
+ if (visible_device_count > 1) {
+ if (desc.core_count() < min_gpu_core_count) {
+ LOG(INFO) << "Ignoring gpu device "
+ << "(" << GetShortDeviceDescription(i, desc) << ") "
+ << "with Cuda multiprocessor count: " << desc.core_count()
+ << ". The minimum required count is " << min_gpu_core_count
+ << ". You can adjust this requirement with the env var "
+ "TF_MIN_GPU_MULTIPROCESSOR_COUNT.";
+ continue;
+ }
+ }
+
+ int new_id = ids->size();
+ ids->push_back(i);
+
+ LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> "
+ << "(" << GetShortDeviceDescription(i, desc) << ")";
+ }
+ }
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
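
To make the memory reservation arithmetic in MinSystemMemory() and
CreateGPUDevice() above concrete, here is a minimal standalone sketch; the
12 GiB figure and the 0.4 fraction are illustrative assumptions, not values
taken from the code:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the heuristic above: reserve 200MiB on small cards, otherwise the
// larger of 300MiB or 5% of the memory currently available.
int64_t MinSystemMemorySketch(int64_t available_memory) {
  if (available_memory < (1LL << 31)) {
    return 209715200LL;  // 200MiB
  }
  return std::max<int64_t>(314572800LL,
                           static_cast<int64_t>(available_memory * 0.05));
}

int main() {
  const int64_t available = 12LL << 30;  // hypothetical card with 12GiB free
  // per_process_gpu_memory_fraction == 0: subtract the system reservation.
  const int64_t reserved = MinSystemMemorySketch(available);
  std::printf("reserve %lld bytes, hand %lld bytes to the GPU allocator\n",
              static_cast<long long>(reserved),
              static_cast<long long>(available - reserved));
  // per_process_gpu_memory_fraction == 0.4: take a flat fraction instead.
  std::printf("with fraction 0.4: %lld bytes\n",
              static_cast<long long>(available * 0.4));
  return 0;
}
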
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
new file mode 100644
index 0000000000..a415224d95
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -0,0 +1,94 @@
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class EigenAllocator;
+
+class BaseGPUDevice : public LocalDevice {
+ public:
+ BaseGPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator);
+
+ ~BaseGPUDevice() override;
+
+ // GPU devices require the Op Compute method to save a reference to
+ // any temporary tensors that are allocated until the Op execution
+ // completes.
+ bool SaveTemporaryTensors() const override { return true; }
+
+ Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map);
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
+
+ Status Sync() override;
+
+ void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) override;
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override;
+
+ // The caller owns the returned device.
+ const PerOpGpuDevice* MakeGpuDevice(DeviceContext* dc,
+ Allocator* allocator) override;
+
+ protected:
+ Allocator* gpu_allocator_; // not owned
+ Allocator* cpu_allocator_; // not owned
+
+ private:
+ std::vector<gpu::Stream*> streams_;
+ std::vector<GPUDeviceContext*> device_contexts_;
+ GpuDeviceInfo* gpu_device_info_ = nullptr;
+ mutex trace_mu_;
+ int gpu_id_ = -1;
+ std::unique_ptr<EventMgr> em_;
+
+ const PerOpGpuDevice* NewDevice(int stream_id, Allocator* allocator);
+};
+
+class BaseGPUDeviceFactory : public DeviceFactory {
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
+
+ private:
+ LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, int gpu_id);
+
+ virtual LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc,
+ Allocator* gpu_allocator,
+ Allocator* cpu_allocator) = 0;
+
+ void GetValidDeviceIds(std::vector<int>* ids);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
new file mode 100644
index 0000000000..240ac47499
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -0,0 +1,52 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+
+namespace tensorflow {
+
+void RequireGPUDevice() {}
+
+class GPUDevice : public BaseGPUDevice {
+ public:
+ GPUDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator)
+ : BaseGPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
+ physical_device_desc, gpu_allocator, cpu_allocator) {}
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ if (attr.on_host()) {
+ ProcessState* ps = ProcessState::singleton();
+ if (attr.gpu_compatible()) {
+ return ps->GetCUDAHostAllocator(0);
+ } else {
+ return cpu_allocator_;
+ }
+ } else {
+ return gpu_allocator_;
+ }
+ }
+};
+
+class GPUDeviceFactory : public BaseGPUDeviceFactory {
+ private:
+ LocalDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc,
+ Allocator* gpu_allocator,
+ Allocator* cpu_allocator) override {
+ return new GPUDevice(options, name, memory_limit, bus_adjacency, gpu_id,
+ physical_device_desc, gpu_allocator, cpu_allocator);
+ }
+};
+
+REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
new file mode 100644
index 0000000000..29d6281733
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -0,0 +1,132 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+EventMgr::EventMgr(gpu::StreamExecutor* se)
+ : exec_(se),
+ // threadpool_ has 1 thread for the polling loop, and one to execute
+ // event callback functions. Maybe we should have more?
+ threadpool_(Env::Default(), "GPU_Event_Manager", 2) {
+ threadpool_.Schedule([this]() { PollLoop(); });
+}
+
+EventMgr::~EventMgr() {
+ stop_polling_.Notify();
+ // Shut down the backup polling loop.
+ polling_stopped_.WaitForNotification();
+
+ // Events are owned by this object.
+ for (auto& e : free_events_) {
+ delete e;
+ }
+ while (!used_events_.empty()) {
+ delete used_events_[0].event;
+ delete used_events_[0].mem;
+ if (used_events_[0].bufrec.buf) {
+ used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf);
+ }
+ if (used_events_[0].func != nullptr)
+ threadpool_.Schedule(used_events_[0].func);
+ used_events_.pop_front();
+ }
+}
+
+// This polling loop runs at a relatively low frequency. Most calls to
+// PollEvents() should come directly from Compute() via
+// ThenDeleteTensors(). This function's purpose is to ensure that
+// even if no more GPU operations are being requested, we still
+// eventually clear the queue. It seems to prevent some tensorflow
+// programs from stalling for reasons not yet understood.
+void EventMgr::PollLoop() {
+ while (!stop_polling_.HasBeenNotified()) {
+ Env::Default()->SleepForMicroseconds(1 * 1000);
+ {
+ mutex_lock l(mu_);
+ PollEvents(true);
+ }
+ }
+ polling_stopped_.Notify();
+}
+
+void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
+ VLOG(2) << "QueueInUse free_events_ " << free_events_.size()
+ << " used_events_ " << used_events_.size();
+ // Events are created on demand, and repeatedly reused. There is no
+ // limit placed here on the number of allocated Events.
+ if (free_events_.empty()) {
+ free_events_.push_back(new gpu::Event(exec_));
+ free_events_.back()->Init();
+ }
+ gpu::Event* e = free_events_.back();
+ free_events_.pop_back();
+ stream->ThenRecordEvent(e);
+ iu.event = e;
+ used_events_.push_back(iu);
+}
+
+// This function must be called periodically to check whether pending
+// events have recorded, and then retire them. Initial observations
+// suggest that typical behavior in a TensorFlow program is to have
+// 0-3 events pending most of the time, but there are occasionally
+// spikes of up to several hundred outstanding.
+//
+// NOTE: If all events are on the same stream, no later event will
+// complete before an earlier event, except possibly if the earlier
+// event transitions to an error state, so there's no advantage in
+// looking past the first kPending event. However, if we're using
+// multiple streams there may be some gain in looking deeper.
+// As a compromise, PollEvents() calls that are triggered by the queueing
+// of a single event never look past the first kPending event. Calls
+// coming from the dedicated polling thread always sweep the full queue.
+//
+// Note that allowing the queue to grow very long could cause overall
+// GPU memory use to spike needlessly. An alternative strategy would
+// be to throttle new Op execution until the pending event queue
+// clears.
+void EventMgr::PollEvents(bool is_dedicated_poller) {
+ VLOG(2) << "PollEvents free_events_ " << free_events_.size()
+ << " used_events_ " << used_events_.size();
+ // Sweep the remaining events in order. If this is the dedicated
+ // polling thread, check the entire set. Otherwise, just sweep up to
+ // the first non-complete record that is still pending.
+ for (auto& iu : used_events_) {
+ if (iu.event == nullptr) continue;
+ gpu::Event::Status s = iu.event->PollForStatus();
+ switch (s) {
+ case gpu::Event::Status::kUnknown:
+ case gpu::Event::Status::kError:
+ // We don't expect to see these. Someday maybe propagate
+ // a Status error, but for now fail hard.
+ LOG(FATAL) << "Unexpected Event status: " << static_cast<int>(s);
+ break;
+ case gpu::Event::Status::kPending:
+ if (!is_dedicated_poller) return; // quit processing queue
+ break;
+ case gpu::Event::Status::kComplete:
+ delete iu.mem;
+ if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
+ // The function must be called in another thread, outside of
+ // the mutex held here.
+ if (iu.func != nullptr) threadpool_.Schedule(iu.func);
+ free_events_.push_back(iu.event);
+ // Mark this InUse record as completed.
+ iu.event = nullptr;
+ }
+ }
+ // Then clear any completed InUse records from the front of the queue.
+ while (!used_events_.empty()) {
+ InUse& iu = used_events_.front();
+ if (iu.event == nullptr) {
+ used_events_.pop_front();
+ } else {
+ break;
+ }
+ }
+}
+
+} // namespace tensorflow
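
The sweep rule documented above PollEvents() can be exercised in isolation
with the self-contained model below; FakeInUse and its bool flags are
illustrative stand-ins, not part of the EventMgr API:

#include <cstdio>
#include <deque>
#include <functional>

struct FakeInUse {
  bool complete = false;       // stands in for Event::Status::kComplete
  std::function<void()> func;  // optional callback, as in EventMgr::InUse
  bool retired = false;        // analogous to setting iu.event = nullptr
};

// Retire completed records in FIFO order.  A sweep triggered by a single
// enqueue stops at the first pending record; the dedicated poller checks
// every record and then trims the completed prefix of the queue.
void SweepSketch(std::deque<FakeInUse>* used, bool is_dedicated_poller) {
  for (FakeInUse& iu : *used) {
    if (iu.retired) continue;
    if (!iu.complete) {
      if (!is_dedicated_poller) return;  // quit processing the queue
      continue;
    }
    if (iu.func) iu.func();
    iu.retired = true;
  }
  while (!used->empty() && used->front().retired) used->pop_front();
}

int main() {
  std::deque<FakeInUse> q(3);
  q[0].complete = true;   // oldest record has completed
  q[1].complete = false;  // second is still pending
  q[2].complete = true;   // third completed out of order (multi-stream case)
  SweepSketch(&q, /*is_dedicated_poller=*/false);
  std::printf("after op-triggered sweep: %zu records left\n", q.size());  // 3
  q[1].complete = true;
  SweepSketch(&q, /*is_dedicated_poller=*/true);
  std::printf("after dedicated sweep: %zu records left\n", q.size());     // 0
  return 0;
}

Note that the op-triggered sweep leaves all three records in place: the front
record is retired, but the queue is only trimmed when the sweep reaches its
end, exactly as the early return in PollEvents() above behaves.
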
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
new file mode 100644
index 0000000000..f9436566d4
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -0,0 +1,118 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+
+#include <deque>
+#include <vector>
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace perftools {
+namespace gputools {
+class Event;
+class Stream;
+class StreamExecutor;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+// An object to keep track of pending Events in the StreamExecutor streams
+// and associated Tensors that cannot safely be deleted until the associated
+// Events are recorded.
+class EventMgr {
+ public:
+ explicit EventMgr(perftools::gputools::StreamExecutor* se);
+
+ ~EventMgr();
+
+ // Takes ownership of *tensors and deletes it as soon as all events
+ // currently enqueued on *stream have completed.
+ inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
+ std::vector<Tensor>* tensors) {
+ mutex_lock l(mu_);
+ QueueTensors(stream, tensors);
+ PollEvents(false);
+ }
+
+ struct BufRec {
+ Allocator* alloc;
+ void* buf;
+ };
+
+ // Takes ownership of *bufrec.buf and calls bufrec.alloc->DeallocateRaw()
+ // on it as soon as all events currently enqueued on *stream have completed.
+ inline void ThenDeleteBuffer(perftools::gputools::Stream* stream,
+ BufRec bufrec) {
+ mutex_lock l(mu_);
+ QueueBuffer(stream, bufrec);
+ PollEvents(false);
+ }
+
+ inline void ThenExecute(perftools::gputools::Stream* stream,
+ std::function<void()> func) {
+ mutex_lock l(mu_);
+ QueueFunc(stream, func);
+ PollEvents(false);
+ }
+
+ private:
+ friend class TEST_EventMgrHelper;
+ mutex mu_;
+ perftools::gputools::StreamExecutor* exec_;
+
+ struct InUse {
+ perftools::gputools::Event* event;
+ std::vector<Tensor>* mem;
+ BufRec bufrec;
+ std::function<void()> func;
+ };
+
+ // Stream-enqueue an unused Event and save with it a collection of
+ // Tensors and/or a BufRec to be deleted only after the Event
+ // records.
+ void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ void QueueTensors(perftools::gputools::Stream* stream,
+ std::vector<Tensor>* tensors)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
+ }
+
+ void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr});
+ }
+
+ void QueueFunc(perftools::gputools::Stream* stream,
+ std::function<void()> func) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ QueueInUse(stream, {nullptr, nullptr, BufRec(), func});
+ }
+
+ // This function should be called at roughly the same tempo as
+ // QueueTensors() to check whether pending events have recorded,
+ // and then retire them.
+ void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // An internal polling loop that runs at a low frequency to clear
+ // straggler Events.
+ void PollLoop();
+
+ // A stack of unused events
+ std::vector<perftools::gputools::Event*> free_events_ GUARDED_BY(mu_);
+
+ // A FIFO queue of InUse events and associated tensors.
+ std::deque<InUse> used_events_ GUARDED_BY(mu_);
+
+ Notification stop_polling_;
+ Notification polling_stopped_;
+
+ // The main PollLoop for the event manager runs in this threadpool.
+ thread::ThreadPool threadpool_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
new file mode 100644
index 0000000000..30ca1ff187
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -0,0 +1,152 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+class TEST_EventMgrHelper {
+ public:
+ explicit TEST_EventMgrHelper(EventMgr* em) : em_(em) {}
+
+ int queue_size() {
+ mutex_lock l(em_->mu_);
+ return em_->used_events_.size();
+ }
+
+ int free_size() {
+ mutex_lock l(em_->mu_);
+ return em_->free_events_.size();
+ }
+
+ void QueueTensors(perftools::gputools::Stream* stream,
+ std::vector<Tensor>* tensors) {
+ mutex_lock l(em_->mu_);
+ em_->QueueTensors(stream, tensors);
+ }
+
+ void PollEvents(bool is_dedicated_poller) {
+ mutex_lock l(em_->mu_);
+ em_->PollEvents(is_dedicated_poller);
+ }
+
+ private:
+ EventMgr* em_;
+};
+
+namespace {
+
+TEST(EventMgr, Empty) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+}
+
+// Delaying polling until after several enqueuings should grow the
+// total number of allocated events. Once we have enough events for
+// the max simultaneously pending, we should not allocate any more.
+TEST(EventMgr, DelayedPolling) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(i + 1, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ }
+ th.PollEvents(false);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(5, th.free_size());
+ for (int j = 0; j < 2; ++j) {
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(i + 1, th.queue_size());
+ EXPECT_EQ(4 - i, th.free_size());
+ }
+ th.PollEvents(false);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(5, th.free_size());
+ }
+}
+
+// Immediate polling should require only one event to be allocated.
+TEST(EventMgr, ImmediatePolling) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ em.ThenDeleteTensors(stream.get(), v);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(1, th.free_size());
+ }
+}
+
+// If we delay polling by more than 1 second, the backup polling loop
+// should clear the queue.
+TEST(EventMgr, LongDelayedPolling) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(1 + i, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ }
+ sleep(1);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(5, th.free_size());
+}
+
+// Deleting the EventMgr when events are still pending should shut
+// down gracefully.
+TEST(EventMgr, NonEmptyShutdown) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec);
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ std::vector<Tensor>* v = nullptr;
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ v = new std::vector<Tensor>;
+ th.QueueTensors(stream.get(), v);
+ EXPECT_EQ(1 + i, th.queue_size());
+ EXPECT_EQ(0, th.free_size());
+ }
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.cc b/tensorflow/core/common_runtime/gpu/gpu_init.cc
new file mode 100644
index 0000000000..631a47eb91
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.cc
@@ -0,0 +1,147 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+
+#include <string>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+namespace {
+
+std::unique_ptr<std::map<std::pair<int, int>, bool>> GetPeerAccessMap(
+ gpu::Platform* platform, int device_count) {
+ auto* map = new std::map<std::pair<int, int>, bool>;
+ for (int i = 0; i < device_count; ++i) {
+ for (int j = 0; j < device_count; ++j) {
+ gpu::StreamExecutor* from = platform->ExecutorForDevice(i).ValueOrDie();
+ gpu::StreamExecutor* to = platform->ExecutorForDevice(j).ValueOrDie();
+ (*map)[{i, j}] = from->CanEnablePeerAccessTo(to);
+ }
+ }
+
+ return std::unique_ptr<std::map<std::pair<int, int>, bool>>{map};
+}
+
+Status EnablePeerAccess(gpu::Platform* platform, int device_count) {
+ for (int i = 0; i < device_count; ++i) {
+ for (int j = 0; j < device_count; ++j) {
+ gpu::StreamExecutor* from = platform->ExecutorForDevice(i).ValueOrDie();
+ gpu::StreamExecutor* to = platform->ExecutorForDevice(j).ValueOrDie();
+
+ if (from->CanEnablePeerAccessTo(to)) {
+ auto status = from->EnablePeerAccessTo(to);
+ if (!status.ok()) {
+ return errors::Internal(status.ToString());
+ }
+ } else {
+ LOG(INFO) << "cannot enable peer access from device ordinal " << i
+ << " to device ordinal " << j;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+static void InitGPU() {
+ auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
+ if (!result.ok()) {
+ LOG(WARNING)
+ << "Not initializing the GPU, could not create GPU MachineManager. "
+ << "Error: " << result.status();
+ return;
+ }
+
+ gpu::Platform* platform = result.ValueOrDie();
+
+ int dev_count = platform->VisibleDeviceCount();
+
+ if (dev_count == 0) {
+ LOG(INFO) << "No GPU devices available on machine.";
+ return;
+ }
+
+ for (int i = 0; i < dev_count; ++i) {
+ auto stream_exec = platform->ExecutorForDevice(i).ValueOrDie();
+ int64 free_bytes;
+ int64 total_bytes;
+ if (!stream_exec->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
+ // Logs internally on failure.
+ free_bytes = 0;
+ total_bytes = 0;
+ }
+ const auto& description = stream_exec->GetDeviceDescription();
+ int cc_major;
+ int cc_minor;
+ if (!description.cuda_compute_capability(&cc_major, &cc_minor)) {
+ // Logs internally on failure.
+ cc_major = 0;
+ cc_minor = 0;
+ }
+ LOG(INFO) << "Found device " << i << " with properties: "
+ << "\nname: " << description.name() << "\nmajor: " << cc_major
+ << " minor: " << cc_minor << " memoryClockRate (GHz) "
+ << description.clock_rate_ghz() << "\npciBusID "
+ << description.pci_bus_id() << "\nTotal memory: "
+ << strings::HumanReadableNumBytes(total_bytes)
+ << "\nFree memory: "
+ << strings::HumanReadableNumBytes(free_bytes);
+ }
+
+ // Enable peer access
+
+ auto status = EnablePeerAccess(platform, dev_count);
+ if (!status.ok()) {
+ LOG(FATAL) << "could not enable peer access for GPU devices: " << status;
+ }
+
+ // Print out a matrix showing which devices can DMA to one
+ // another.
+ auto access_map = GetPeerAccessMap(platform, dev_count);
+ string line_buf = "DMA: ";
+ for (int i = 0; i < dev_count; ++i) {
+ strings::StrAppend(&line_buf, i, " ");
+ }
+ LOG(INFO) << line_buf;
+ for (int i = 0; i < dev_count; ++i) {
+ line_buf = strings::StrCat(i, ": ");
+ for (int j = 0; j < dev_count; ++j) {
+ if ((*access_map)[{i, j}]) {
+ line_buf.append("Y ");
+ } else {
+ line_buf.append("N ");
+ }
+ }
+ LOG(INFO) << line_buf;
+ }
+}
+
+static bool InitModule() {
+ InitGPU();
+ return true;
+}
+
+} // namespace
+
+gpu::Platform* GPUMachineManager() {
+ // Create the machine manager singleton and initialize the GPUs only
+ // once.
+ static bool init = InitModule();
+ CHECK(init); // Avoids compiler warning that init is unused.
+
+ auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
+ if (!result.ok()) {
+ return nullptr;
+ }
+
+ return result.ValueOrDie();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.h b/tensorflow/core/common_runtime/gpu/gpu_init.h
new file mode 100644
index 0000000000..d126a8b1ca
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+
+namespace perftools {
+namespace gputools {
+class Platform;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+// Returns the GPU machine manager singleton, creating it and
+// initializing the GPUs on the machine if needed the first time it is
+// called.
+perftools::gputools::Platform* GPUMachineManager();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc
new file mode 100644
index 0000000000..08ff55e221
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc
@@ -0,0 +1,371 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h"
+
+//#include "base/commandlineflags.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_bool(brain_gpu_region_allocator_heap_check_on_destruction, true,
+ "If true, the CUDA gpu manager checks that all allocated "
+ "memory through the GPU memory pool implementation has been "
+ "freed.");
+
+DEFINE_int64(brain_gpu_region_allocator_region_size, 0,
+ "If > 0, sets the default chunk-size allocatable from GPU memory. "
+ "Else defaults to entire GPU memory.");
+
+#else
+bool FLAGS_brain_gpu_region_allocator_heap_check_on_destruction = true;
+tensorflow::int64 FLAGS_brain_gpu_region_allocator_region_size = 0;
+#endif
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+GPURegionAllocator::GPURegionAllocator(int device_id, size_t total_bytes)
+ : device_id_(device_id), total_bytes_(total_bytes) {
+ // Get a pointer to the stream_executor for this device
+ stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie();
+
+ // Set the region size based on explicit user request, or based on
+ // total GPU capacity.
+ if (FLAGS_brain_gpu_region_allocator_region_size > 0) {
+ region_size_ = FLAGS_brain_gpu_region_allocator_region_size;
+ } else {
+ region_size_ = static_cast<size_t>(total_bytes_);
+ }
+
+ LOG(INFO) << "Setting region size to " << region_size_;
+}
+
+GPURegionAllocator::~GPURegionAllocator() {
+ if (FLAGS_brain_gpu_region_allocator_heap_check_on_destruction) {
+ CheckForMemoryLeaks();
+ }
+
+ gtl::STLDeleteValues(&chunk_map_);
+
+ for (auto r : regions_) {
+ gpu::DeviceMemoryBase gpu_ptr{r->ptr};
+ stream_exec_->Deallocate(&gpu_ptr);
+ delete r;
+ }
+}
+
+void* GPURegionAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ static const int64 kMaxMillisToWait = 10000; // 10 seconds
+ return retry_helper_.AllocateRaw(
+ [this](size_t a, size_t nb, bool v) {
+ return AllocateRawInternal(a, nb, v);
+ },
+ kMaxMillisToWait, alignment, num_bytes);
+}
+
+void* GPURegionAllocator::AllocateRawInternal(size_t alignment,
+ size_t num_bytes,
+ bool dump_log_on_failure) {
+ if (num_bytes == 0) {
+ LOG(ERROR) << "tried to allocate 0 bytes";
+ return nullptr;
+ }
+ size_t chunk_size = ChunkSize(num_bytes);
+
+ VLOG(2) << "chunk_size " << chunk_size << " from num_bytes "
+ << strings::HumanReadableNumBytes(num_bytes);
+ mutex_lock l(lock_);
+ Pool* pool = &pools_[chunk_size];
+ if (pool->num_free == 0) {
+ if (!ExpandPool(pool, chunk_size, num_bytes, dump_log_on_failure)) {
+ if (dump_log_on_failure) {
+ LOG(WARNING) << "Out of GPU memory, see memory state dump above";
+ }
+ return nullptr;
+ }
+ }
+ CHECK_LT(0, pool->num_free);
+ CHECK(pool->first);
+ CHECK(pool->last);
+ Chunk* c = pool->first;
+ CHECK(c);
+ CHECK(!c->in_use);
+
+ c->in_use = true;
+ // Move c to the back of the queue.
+ if (c->next != nullptr) {
+ pool->first = c->next;
+ pool->first->prev = nullptr;
+ c->next = nullptr;
+ }
+
+ if (pool->last != c) {
+ pool->last->next = c;
+ c->prev = pool->last;
+ pool->last = c;
+ }
+ pool->num_free--;
+ pool->cumulative_malloced++;
+
+ void* rv = c->ptr;
+ c->bytes_allocated = num_bytes;
+
+ VLOG(2) << "new ptr " << rv;
+ return rv;
+}
+
+void GPURegionAllocator::DeallocateRaw(void* ptr) {
+ retry_helper_.DeallocateRaw([this](void* p) { DeallocateRawInternal(p); },
+ ptr);
+}
+
+void GPURegionAllocator::DeallocateRawInternal(void* ptr) {
+ VLOG(2) << "DeallocateRaw: " << ptr;
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+
+ mutex_lock l(lock_);
+ ChunkMap::const_iterator iter = chunk_map_.find(ptr);
+ CHECK(iter != chunk_map_.end());
+
+ Chunk* c = iter->second;
+ VLOG(2) << "chunk of size " << c->size << " at " << c;
+
+ Pool* pool = &(pools_[c->size]);
+ // Move chunk to head of queue, and mark free.
+ DCHECK(c->in_use);
+ c->in_use = false;
+ if (c->prev) c->prev->next = c->next;
+ if (c->next) c->next->prev = c->prev;
+ if (pool->first == c) pool->first = c->next;
+ if (pool->last == c) pool->last = c->prev;
+ c->next = pool->first;
+ c->prev = nullptr;
+ if (c->next) c->next->prev = c;
+ pool->first = c;
+ if (pool->last == nullptr) pool->last = c;
+ pool->num_free++;
+ pool->cumulative_freed++;
+}
+
+bool GPURegionAllocator::ExpandPool(Pool* pool, size_t chunk_size,
+ size_t requested_size,
+ bool dump_log_on_failure) {
+ VLOG(1) << "ExpandPool of " << chunk_size << " from " << pool->num_chunks
+ << " current members";
+ DCHECK_NE(0, chunk_size);
+ // If chunk_size is < 4096, double the pool size. Otherwise
+ // just increase by one.
+ int num_chunks = pool->num_chunks;
+ if (num_chunks == 0) {
+ if (chunk_size > 4096) {
+ num_chunks = 1;
+ } else {
+ num_chunks = 4096 / chunk_size;
+ }
+ }
+ // For larger chunks, limit the amount of expansion.
+ size_t aggregate_size = num_chunks * chunk_size;
+ if (aggregate_size > (1 << 20)) {
+ num_chunks = static_cast<int>(
+ std::max(static_cast<size_t>(1), (1 << 20) / chunk_size));
+ }
+ while (num_chunks > 0) {
+ Region* r = (regions_.empty() ? nullptr : regions_.back());
+ if (r == nullptr ||
+ (((r->ptr + r->size) - r->next) < static_cast<int64>(chunk_size))) {
+ // Current region is not large enough to accommodate another chunk.
+ while (r == nullptr || (((r->ptr + r->size) - r->next) <
+ static_cast<int64>(chunk_size))) {
+ // Get another region.
+ size_t this_region_size = std::max(region_size_, chunk_size);
+
+ // Check if we would exceed our limit.
+ if (allocated_memory_ + this_region_size > total_bytes_) {
+ if (dump_log_on_failure) DumpMemoryLog();
+ return false;
+ }
+
+ // Perform the allocation, still checking that the allocator
+ // has not run out of memory.
+ gpu::DeviceMemory<char> gpu_mem =
+ stream_exec_->AllocateArray<char>(this_region_size);
+ if (gpu_mem == nullptr) {
+ if (dump_log_on_failure) DumpMemoryLog();
+ return false;
+ }
+
+ // We never release memory once expanded.
+ allocated_memory_ += this_region_size;
+
+ Region* nr = new Region;
+ nr->ptr = static_cast<char*>(gpu_mem.opaque());
+
+ if (VLOG_IS_ON(2)) {
+ int64 free_bytes;
+ int64 total_bytes;
+ if (stream_exec_->DeviceMemoryUsage(&free_bytes, &total_bytes)) {
+ VLOG(2) << "free " << free_bytes << " total " << total_bytes;
+ } else {
+ // Note: stream_exec call also logs internally on failure.
+ VLOG(2) << "could not retrieve memory usage";
+ }
+ }
+ VLOG(1) << "new Region of size " << this_region_size << " at "
+ << static_cast<void*>(nr->ptr) << " on device " << device_id_;
+ r = nr;
+ r->size = this_region_size;
+ r->next = r->ptr;
+ regions_.push_back(r);
+
+ for (auto visitor : region_visitors_) {
+ visitor(r->ptr, r->size);
+ }
+ }
+ } else {
+ // Allocate a new chunk and push on front of Pool.
+ Chunk* c = new Chunk;
+ c->ptr = r->next;
+ chunk_map_[c->ptr] = c;
+ c->size = chunk_size;
+ r->next += chunk_size;
+ c->next = pool->first;
+ if (c->next != nullptr) c->next->prev = c;
+ pool->first = c;
+ if (pool->last == nullptr) pool->last = c;
+ pool->num_chunks++;
+ pool->num_free++;
+ --num_chunks;
+ }
+ }
+
+ return true;
+}
+
+void GPURegionAllocator::CheckForMemoryLeaks() {
+ std::vector<string> errors;
+ mutex_lock l(lock_); // could use reader lock
+ for (auto pool_map : pools_) {
+ const Pool& p = pool_map.second;
+ Chunk* curr_chunk = p.first;
+ while (curr_chunk != nullptr) {
+ if (curr_chunk->in_use) {
+ errors.push_back(
+ strings::StrCat("Unfreed chunk of size ", curr_chunk->size));
+ }
+ curr_chunk = curr_chunk->next;
+ }
+ }
+ if (!errors.empty()) {
+ LOG(FATAL) << "GPU Memory leaks:\n" << str_util::Join(errors, "\n");
+ }
+}
+
+// Since there's no merging of chunks once allocated, we want to
+// maximize their reusability (which argues for fewer, larger sizes),
+// while minimizing waste (which argues for tight-fitting sizes).
+//
+// The smallest unit of allocation is 256 bytes.
+// NOTE(tucker): akrizhevsky says that nvidia's memory manager always
+// aligns to 256 bytes, and doing so results in significant speedup.
+//
+// Up to 2^16 bytes we only allocate in powers of 2.
+//
+// Above that, we pick a max-waste which is the largest power
+// of 2 <= 1/16 of the requested size, then round up to the nearest
+// multiple of max_waste.
+//
+// static
+size_t GPURegionAllocator::ChunkSize(size_t bytes) {
+ if (bytes <= 256) {
+ return 256;
+ } else if (bytes <= (1 << 16)) {
+ return 1uLL << Log2Ceiling64(bytes);
+ } else {
+ // 1/16th of requested size
+ size_t max_waste = 1uLL << (Log2Ceiling64(bytes) - 4);
+ return (bytes + max_waste) & (~(max_waste - 1));
+ }
+}
+
+void GPURegionAllocator::AddAllocVisitor(Visitor visitor) {
+ VLOG(1) << "AddVisitor";
+ mutex_lock l(lock_);
+ region_visitors_.push_back(visitor);
+ for (auto region : regions_) {
+ visitor(region->ptr, region->size);
+ }
+}
+
+void GPURegionAllocator::DumpMemoryLog() {
+ size_t region_bytes = 0;
+ for (auto r : regions_) {
+ region_bytes += r->size;
+ }
+ size_t chunk_bytes = 0;
+ std::vector<size_t> chunk_sizes;
+ for (auto i : pools_) {
+ chunk_sizes.push_back(i.first);
+ }
+ std::sort(chunk_sizes.begin(), chunk_sizes.end());
+ for (auto i : chunk_sizes) {
+ int32 chunks_in_use = 0;
+ const Pool& p = pools_[i];
+ chunk_bytes += i * p.num_chunks;
+
+ if (p.num_chunks > 0) {
+ // Iterate backwards (allocated chunks are last).
+ Chunk* curr_chunk = p.last;
+ while (curr_chunk != nullptr) {
+ if (curr_chunk->in_use) {
+ ++chunks_in_use;
+ }
+ curr_chunk = curr_chunk->prev;
+ if (curr_chunk == p.first) {
+ break;
+ }
+ }
+ }
+
+ LOG(INFO) << "Chunk size: " << i << " ("
+ << strings::HumanReadableNumBytes(i) << ") Pool: " << p.ToString()
+ << "\nNumber of chunks: " << p.num_chunks
+ << ", in_use chunks: " << chunks_in_use;
+ }
+
+ LOG(INFO) << "Aggregate Region Memory: " << region_bytes << " ("
+ << strings::HumanReadableNumBytes(region_bytes) << ")";
+ LOG(INFO) << "Aggregate Chunk Memory: " << chunk_bytes << " ("
+ << strings::HumanReadableNumBytes(chunk_bytes) << ")";
+}
+
+bool GPURegionAllocator::TracksAllocationSizes() { return true; }
+
+size_t GPURegionAllocator::RequestedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = chunk_map_.find(ptr);
+ CHECK(it != chunk_map_.end())
+ << "Asked for requested size of pointer we never allocated: " << ptr;
+ auto c = it->second;
+ return c->bytes_allocated;
+}
+
+size_t GPURegionAllocator::AllocatedSize(void* ptr) {
+ mutex_lock l(lock_);
+ auto it = chunk_map_.find(ptr);
+ CHECK(it != chunk_map_.end())
+ << "Asked for allocated size of pointer we never allocated: " << ptr;
+ auto c = it->second;
+ return c->size;
+}
+
+} // namespace tensorflow
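
The rounding rule described above ChunkSize() can be reproduced with the short
standalone program below; Log2CeilingSketch is a local stand-in for the
library helper and the sample byte counts are arbitrary:

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Smallest exponent e with 2^e >= n (for n >= 1).
int Log2CeilingSketch(uint64_t n) {
  int log = 0;
  uint64_t value = 1;
  while (value < n) {
    value <<= 1;
    ++log;
  }
  return log;
}

uint64_t ChunkSizeSketch(uint64_t bytes) {
  if (bytes <= 256) return 256;  // minimum granule, for 256-byte alignment
  if (bytes <= (1 << 16)) {
    return 1uLL << Log2CeilingSketch(bytes);  // next power of two
  }
  // Waste bound tied to the size class, then round up to a multiple of it.
  uint64_t max_waste = 1uLL << (Log2CeilingSketch(bytes) - 4);
  return (bytes + max_waste) & ~(max_waste - 1);
}

int main() {
  for (uint64_t b : {100ULL, 300ULL, 40000ULL, 70000ULL, 1000000ULL}) {
    std::printf("%8llu bytes -> chunk of %8llu\n",
                static_cast<unsigned long long>(b),
                static_cast<unsigned long long>(ChunkSizeSketch(b)));
  }
  return 0;
}

For example, a 70000-byte request falls in the 2^17 size class, giving an
8192-byte granularity, and is served from 73728-byte chunks.
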
diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h
new file mode 100644
index 0000000000..1a250b6ede
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator.h
@@ -0,0 +1,146 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_allocator_retry.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class GPURegionAllocator : public VisitableAllocator {
+ public:
+ // 'device_id' must be a valid device on the machine.
+ //
+ // total_bytes is how many bytes this allocator should allocate up
+ // to. This may be less than the total available.
+ explicit GPURegionAllocator(int device_id, size_t total_bytes);
+ ~GPURegionAllocator() override;
+
+ string Name() override { return "gpu_region"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void AddAllocVisitor(Visitor visitor) override;
+ // Does nothing, because regions are never freed.
+ void AddFreeVisitor(Visitor visitor) override {}
+
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ private:
+ // A Chunk is the header on a single piece of memory given back
+ // in response to an AllocateRaw() call.
+ struct Chunk {
+ char* ptr; // pointer to granted GPU buffer.
+ size_t size; // Full size of GPU buffer.
+ size_t bytes_allocated; // Bytes asked for by client.
+ bool in_use;
+ Chunk* prev; // Used for chaining in pool.
+ Chunk* next;
+ Chunk()
+ : ptr(nullptr),
+ size(0),
+ bytes_allocated(0),
+ in_use(false),
+ prev(nullptr),
+ next(nullptr) {}
+ };
+
+ // A Pool is a collection of same-sized Chunks.
+ struct Pool {
+ int num_chunks; // total chunks in this pool
+ int num_free; // total free chunks in this pool
+ int64 cumulative_malloced; // number of chunks malloced so far
+ int64 cumulative_freed; // number of chunks freed so far
+
+    // doubly-linked list of chunks; all free chunks precede all
+    // granted chunks
+ Chunk* first;
+ Chunk* last;
+ Pool()
+ : num_chunks(0),
+ num_free(0),
+ cumulative_malloced(0),
+ cumulative_freed(0),
+ first(nullptr),
+ last(nullptr) {}
+
+ string ToString() const {
+ return strings::StrCat("chunks: ", num_chunks, " free: ", num_free,
+ " cumulative malloc: ", cumulative_malloced,
+ " cumulative freed: ", cumulative_freed);
+ }
+ };
+
+ // A Region is a single area of GPU memory that has been
+ // reserved by this class and carved up into Chunks.
+ struct Region {
+ char* ptr; // base GPU ptr
+ char* next; // frontier of unused part of region
+ size_t size;
+ Region() : ptr(nullptr), size(0) {}
+ };
+
+ // Calculate size of chunk for an allocation of this size.
+  // Min chunk size is 256, for alignment.
+ // For larger sizes, we round up somewhat so there are fewer
+ // size-specific pools.
+ static size_t ChunkSize(size_t bytes);
+
+ void* AllocateRawInternal(size_t alignment, size_t num_bytes,
+ bool dump_log_on_failure);
+ void DeallocateRawInternal(void* ptr);
+
+ bool ExpandPool(Pool* p, size_t chunk_size, size_t requested_size,
+ bool dump_log_on_failure) EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ // Inspects region maps and crashes with debug information if there
+ // are any memory leaks as detected by the region allocator.
+ void CheckForMemoryLeaks() LOCKS_EXCLUDED(lock_);
+
+ void DumpMemoryLog() EXCLUSIVE_LOCKS_REQUIRED(lock_);
+
+ perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+
+ typedef std::unordered_map<size_t, Pool> PoolMap;
+ typedef std::unordered_map<void*, Chunk*> ChunkMap;
+
+ GPUAllocatorRetry retry_helper_;
+ mutable mutex lock_;
+ PoolMap pools_ GUARDED_BY(lock_);
+
+ // Owns regions.
+ std::vector<Region*> regions_ GUARDED_BY(lock_);
+
+ // Maps from GPU ptr to Chunk owning it.
+ //
+ // Owns chunks.
+ ChunkMap chunk_map_ GUARDED_BY(lock_);
+
+ // Called once on each region, ASAP.
+ std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
+
+ const int device_id_;
+
+ // Total amount of memory (in bytes) available to this Allocator
+ const size_t total_bytes_;
+
+ // Total amount of memory allocated to regions.
+ size_t allocated_memory_ = 0;
+
+ size_t region_size_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPURegionAllocator);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_REGION_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc
new file mode 100644
index 0000000000..07b0dd57f6
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_region_allocator_test.cc
@@ -0,0 +1,71 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+namespace {
+
+TEST(GPURegionAllocatorTest, Simple) {
+ GPURegionAllocator a(0, 1 << 26);
+ std::vector<void*> ptrs;
+ for (int s = 1; s < 1024; s++) {
+ void* raw = a.AllocateRaw(1, s);
+ ptrs.push_back(raw);
+ }
+ std::sort(ptrs.begin(), ptrs.end());
+ for (int i = 0; i < ptrs.size(); i++) {
+ if (i > 0) {
+ CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups
+ }
+ a.DeallocateRaw(ptrs[i]);
+ }
+ float* t1 = a.Allocate<float>(1024);
+ double* t2 = a.Allocate<double>(1048576);
+ a.Deallocate(t1);
+ a.Deallocate(t2);
+}
+
+TEST(GPURegionAllocatorTest, CheckMemLeak) {
+ EXPECT_DEATH(
+ {
+ GPURegionAllocator a(0, 1 << 26);
+ float* t1 = a.Allocate<float>(1024);
+ if (t1) {
+ LOG(INFO) << "Not deallocating";
+ }
+ },
+ "");
+}
+
+TEST(GPURegionAllocatorTest, TracksSizes) {
+ GPURegionAllocator a(0, 1 << 26);
+ EXPECT_EQ(true, a.TracksAllocationSizes());
+}
+
+TEST(GPURegionAllocatorTest, AllocatedVsRequested) {
+ GPURegionAllocator a(0, 1 << 26);
+ float* t1 = a.Allocate<float>(1);
+ EXPECT_EQ(sizeof(float), a.RequestedSize(t1));
+
+  // Minimum allocation size is 256
+ EXPECT_EQ(256, a.AllocatedSize(t1));
+
+ a.Deallocate(t1);
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc
new file mode 100644
index 0000000000..ca86c7fa06
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.cc
@@ -0,0 +1,97 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
+
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace gpu_stream_util {
+
+Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
+ std::unordered_map<int, int>* node_to_stream_id) {
+ VLOG(1) << "AssignStreams";
+ Status status;
+
+ // Sanity check arguments.
+ if (graph == nullptr)
+ status.Update(errors::InvalidArgument("Bad graph argument supplied."));
+ if (node_to_stream_id == nullptr) {
+ status.Update(
+ errors::InvalidArgument("Bad node_to_stream_id argument supplied."));
+ }
+ if ((opts.max_streams < 1) || (opts.send_stream >= opts.max_streams) ||
+ (opts.recv_stream >= opts.max_streams) ||
+ (opts.const_stream >= opts.max_streams) ||
+ (opts.compute_stream >= opts.max_streams)) {
+ status.Update(errors::InvalidArgument("Bad graph argument supplied."));
+ }
+ TF_RETURN_IF_ERROR(status);
+
+ // Topologically sort the nodes.
+ std::vector<Node*> order;
+ GetReversePostOrder(*graph, &order);
+ if (VLOG_IS_ON(2)) {
+ for (Node* n : order) {
+ const int node_id = n->id();
+ VLOG(2) << "Node " << node_id << " " << n->type_string() << " "
+ << n->name() << " " << n->in_edges().size() << " inputs";
+ for (const Edge* e : n->in_edges()) {
+ VLOG(2) << " Edge from " << e->src()->id() << " " << e->src()->name()
+ << " fanout " << e->src()->out_edges().size();
+ }
+ }
+ }
+  // We perform stream assignment assuming a large number of
+  // stream IDs and then map these down to the required number of streams
+  // using simple round-robin.
+  // Stream Assignment strategy:
+  // 1. Nodes with zero inputs are always executed on a
+  // fresh stream.
+  // 2. Try to execute a node on the same stream as one of its
+  // inputs to avoid inter-stream dependencies.
+  // 3. If any input comes from a node with a large fanout, that is
+  // perhaps an indication that it is shared between parallel
+  // streams of work. We choose a new stream here so that all consumers
+  // of the tensor are likely to run in parallel.
+ int highest_stream_id = -1;
+ for (Node* n : order) {
+ VLOG(3) << "Inspecting node " << n->DebugString();
+ const int node_id = n->id();
+ const string& op = n->type_string();
+
+ // Determine a suitable stream to use.
+ int stream_id = highest_stream_id + 1;
+ for (const Edge* e : n->in_edges()) {
+ const int fanout = e->src()->out_edges().size();
+ if (fanout == 1) {
+ stream_id = (*node_to_stream_id)[e->src()->id()];
+ break;
+ }
+ }
+ // Override stream for specific op types.
+ if (op == "_Send") {
+ if (opts.send_stream >= 0) stream_id = opts.send_stream;
+ } else if (op == "_Recv") {
+ if (opts.recv_stream >= 0) stream_id = opts.recv_stream;
+ } else if (op == "Const") {
+ if (opts.const_stream >= 0) stream_id = opts.const_stream;
+ } else {
+ if (opts.compute_stream >= 0) stream_id = opts.compute_stream;
+ }
+
+ (*node_to_stream_id)[node_id] = stream_id % opts.max_streams;
+ highest_stream_id = std::max(stream_id, highest_stream_id);
+ }
+ VLOG(1) << "Identified " << highest_stream_id << " candidate streams for "
+ << order.size() << " nodes.";
+
+ return Status::OK();
+}
+
+} // namespace gpu_stream_util
+} // namespace tensorflow
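
For a concrete feel of the AssignStreams() heuristic above, the toy program
below applies the same rules to a four-node diamond graph; ToyNode and the
hard-coded edges are hypothetical and do not use the real Graph API:

#include <algorithm>
#include <cstdio>
#include <vector>

struct ToyNode {
  std::vector<int> inputs;  // ids of producer nodes
  int fanout = 0;           // number of consumers
};

int main() {
  // A diamond: node 0 feeds nodes 1 and 2, which both feed node 3.
  std::vector<ToyNode> nodes(4);
  nodes[1].inputs = {0};
  nodes[2].inputs = {0};
  nodes[3].inputs = {1, 2};
  for (const ToyNode& n : nodes) {
    for (int src : n.inputs) nodes[src].fanout++;
  }

  const int max_streams = 2;
  std::vector<int> stream_of(nodes.size(), -1);
  int highest_stream_id = -1;
  for (int id = 0; id < static_cast<int>(nodes.size()); ++id) {  // topo order
    int stream_id = highest_stream_id + 1;  // rule 1: default to a fresh stream
    for (int src : nodes[id].inputs) {
      if (nodes[src].fanout == 1) {  // rule 2: reuse a single-consumer input
        stream_id = stream_of[src];
        break;
      }
    }
    stream_of[id] = stream_id % max_streams;  // round-robin down to max_streams
    highest_stream_id = std::max(stream_id, highest_stream_id);
  }

  // Rule 1 gives node 0 candidate stream 0.  Rule 3 applies to nodes 1 and 2
  // (their only input has fanout 2), so they get fresh candidates 1 and 2,
  // which map down to streams 1 and 0.  Rule 2 lets node 3 reuse the stream
  // of its fanout-1 input, node 1.
  for (int id = 0; id < 4; ++id) {
    std::printf("node %d -> stream %d\n", id, stream_of[id]);
  }
  return 0;
}
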
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
new file mode 100644
index 0000000000..e1c623382c
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
@@ -0,0 +1,30 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace gpu_stream_util {
+
+struct AssignStreamsOpts {
+ int32 max_streams = 1;
+ // The following options specify a stream to use for specific op
+ // types. The value -1 allows ops to be assigned to any stream.
+ int32 send_stream = -1;
+ int32 recv_stream = -1;
+ int32 const_stream = -1;
+ int32 compute_stream = -1;
+};
+
+// Given the input graph, assigns every node in the graph with a
+// stream_id that should be used.
+Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
+ std::unordered_map<int, int>* node_to_stream_id);
+
+} // namespace gpu_stream_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
new file mode 100644
index 0000000000..5c426caaef
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
@@ -0,0 +1,137 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+class GpuStreamUtilTest : public OpsTestBase {
+ protected:
+ void SetUp() override { RequireDefaultOps(); }
+};
+
+TEST_F(GpuStreamUtilTest, BogusOpts) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ Status status;
+ status = gpu_stream_util::AssignStreams(nullptr, opts, &node_to_stream_id);
+ EXPECT_FALSE(status.ok());
+ status = gpu_stream_util::AssignStreams(&g, opts, nullptr);
+ EXPECT_FALSE(status.ok());
+ opts.max_streams = 0;
+ status = gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id);
+ EXPECT_FALSE(status.ok());
+ opts.max_streams = 1;
+ opts.compute_stream = 5;
+ status = gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id);
+ EXPECT_FALSE(status.ok());
+}
+
+TEST_F(GpuStreamUtilTest, EmptyGraph) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+ EXPECT_EQ(2, node_to_stream_id.size()); // _SOURCE and _SINK
+}
+
+TEST_F(GpuStreamUtilTest, SimpleGraphOneStream) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
+ ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+
+ // There should be 5 nodes assigned.
+ EXPECT_EQ(5, node_to_stream_id.size());
+
+ // All of them should have stream 0.
+ for (const auto& it : node_to_stream_id) {
+ EXPECT_EQ(0, it.second);
+ }
+}
+
+TEST_F(GpuStreamUtilTest, SimpleGraphManyStreams) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
+ ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ opts.max_streams = 3;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+
+ // There should be 5 nodes assigned.
+ EXPECT_EQ(5, node_to_stream_id.size());
+
+ // All of them should have a stream in the range [0..max_streams).
+ for (const auto& it : node_to_stream_id) {
+ EXPECT_GE(it.second, 0);
+ EXPECT_LT(it.second, opts.max_streams);
+ }
+}
+
+TEST_F(GpuStreamUtilTest, StreamOverrides) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::_Recv(DT_FLOAT, "input", "/cpu:0", 0, "/gpu:0",
+ b.opts().WithName("input"));
+ auto n = ops::MatMul(ops::Const(Tensor(DT_FLOAT), b.opts()),
+ ops::Const(Tensor(DT_FLOAT), b.opts()), b.opts());
+ ops::_Send(n, "output", "/gpu:0", 0, "/cpu:0", b.opts().WithName("output"));
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+
+ // Perform stream assignment using a large number of streams, but with
+ // op types constrained to specific streams.
+ std::unordered_map<int, int> node_to_stream_id;
+ gpu_stream_util::AssignStreamsOpts opts;
+ opts.max_streams = 100;
+ opts.const_stream = 90;
+ opts.send_stream = 91;
+ opts.recv_stream = 92;
+ opts.compute_stream = 93;
+ ASSERT_OK(gpu_stream_util::AssignStreams(&g, opts, &node_to_stream_id));
+
+ // There should be 7 nodes assigned.
+ EXPECT_EQ(7, node_to_stream_id.size()); // including _SOURCE and _SINK
+
+ // Nodes should be assigned to streams by op type.
+ for (const auto& it : node_to_stream_id) {
+ Node* n = g.FindNodeId(it.first);
+ const string op = n->type_string();
+ const int stream = it.second;
+ if (op == "Const") {
+ EXPECT_EQ(stream, 90);
+ } else if (op == "_Send") {
+ EXPECT_EQ(stream, 91);
+ } else if (op == "_Recv") {
+ EXPECT_EQ(stream, 92);
+ } else { // Compute.
+ EXPECT_EQ(stream, 93);
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
new file mode 100644
index 0000000000..a6a3ce01fc
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -0,0 +1,345 @@
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+
+//#include "base/commandlineflags.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/gpu/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/util/util.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+#include "tensorflow/core/platform/stream_executor_util.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_int64(brain_gpu_util_debug_string_maxlen, 128,
+ "When dumping gpu memory, prints up to this many bytes.");
+
+DECLARE_bool(record_mem_types);
+#else
+tensorflow::int64 FLAGS_brain_gpu_util_debug_string_maxlen = 128;
+bool FLAGS_EXPERIMENTAL_brain_gpu_multi_stream = false;
+extern bool FLAGS_record_mem_types;
+#endif
+
+using perftools::gputools::DeviceMemoryBase;
+using perftools::gputools::DeviceMemory;
+using perftools::gputools::Stream;
+
+namespace tensorflow {
+
+namespace gpu = ::perftools::gputools;
+
+/*static*/
+void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead,
+ StatusCallback done) {
+ VLOG(1) << "SetProtoFromGPU device_context " << device_context;
+ // Tensor values need to be copied from GPU to CPU ram so that
+ // we can build the protobuf response for a RecvTensor RPC.
+ // "device context" identifies the stream where the _Send op executed.
+ CHECK(device_context);
+ gpu::Stream* stream =
+ static_cast<const GPUDeviceContext*>(device_context)->stream();
+
+ if (!DMAHelper::CanUseDMA(&tensor)) {
+ done(errors::Internal(strings::StrCat(
+        "GPU copy from non-DMA ", DataTypeString(tensor.dtype()), " tensor")));
+ return;
+ }
+ proto->set_dtype(tensor.dtype());
+ tensor.shape().AsProto(proto->mutable_tensor_shape());
+ // Prepare a Cord with the right data buf size, and DMA the
+ // data over from the GPU buffer. Note that 0-size tensors
+ // do not have a backing buffer.
+ const size_t num_bytes = is_dead ? 0 : tensor.TotalBytes();
+ if (num_bytes > 0) {
+ port::Tracing::ScopedAnnotation annotation("SetProtoFromGPU");
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ char* mb = alloc->Allocate<char>(num_bytes);
+ const char* src_ptr =
+ reinterpret_cast<const char*>(DMAHelper::base(&tensor));
+ DeviceMemoryBase gpu_src_ptr(const_cast<char*>(src_ptr), num_bytes);
+ stream->ThenMemcpy(mb, gpu_src_ptr, num_bytes);
+ // Use of tensor may outlive stack scope, so keep a ref.
+ Tensor* tensor_ref = new Tensor(tensor);
+ dev->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ stream, [stream, done, proto, mb, num_bytes, alloc, tensor_ref]() {
+ if (!stream->ok()) {
+ done(errors::Internal("SetProtoFromGPU: GPU Memcpy failed"));
+ // TODO(pbar) We currently have no way to recover the
+ // worker from a GPU stream in the error state. Until
+ // there is a way to reset the CUDA driver, it is
+ // preferable to crash the process and restart. Tracked
+ // under b/23717097
+ LOG(FATAL) << "SetProtoFromGPU: GPU Memcpy failed";
+ return;
+ }
+ delete tensor_ref;
+ port::CopyFromArray(proto->mutable_tensor_content(), mb, num_bytes);
+ alloc->Deallocate<char>(mb);
+ done(Status::OK());
+ });
+ } else {
+ done(Status::OK());
+ }
+}
+
+typedef ProcessState::MemDesc PMD;
+
+/*static*/
+void GPUUtil::CopyViaDMA(const string& edge_name,
+ DeviceContext* send_dev_context,
+ DeviceContext* recv_dev_context, Device* src,
+ Device* dst, AllocatorAttributes src_alloc_attr,
+ AllocatorAttributes dst_alloc_attr,
+ const Tensor* input, Tensor* output,
+ StatusCallback done) {
+ port::Tracing::ScopedAnnotation annotation(edge_name);
+ VLOG(1) << "CopyViaDMA " << edge_name;
+ size_t total_bytes = input->TotalBytes();
+ // Note that 0-size tensors have no backing buffer.
+ if (total_bytes > 0) {
+ const void* src_ptr = DMAHelper::base(input);
+ void* dst_ptr = DMAHelper::base(output);
+ VLOG(2) << "src_ptr " << src_ptr << " dst_ptr " << dst_ptr;
+ if (FLAGS_record_mem_types) {
+ ProcessState::MemDesc smd = ProcessState::singleton()->PtrType(src_ptr);
+ ProcessState::MemDesc dmd = ProcessState::singleton()->PtrType(dst_ptr);
+ VLOG(0) << "Src " << smd.DebugString() << " Dst " << dmd.DebugString();
+ if (smd.loc == PMD::CPU && dmd.loc == PMD::GPU && (!smd.gpu_registered)) {
+ LOG(WARNING) << "CPU -> GPU no reg for " << edge_name;
+ }
+ if (dmd.loc == PMD::CPU && smd.loc == PMD::GPU && (!dmd.gpu_registered)) {
+ LOG(WARNING) << "GPU -> CPU no reg for " << edge_name;
+ }
+ }
+
+ auto src_device_type = src->attributes().device_type();
+ auto dst_device_type = dst->attributes().device_type();
+
+ bool non_cpu_src = (!src_alloc_attr.on_host() &&
+ src_device_type != DeviceType(DEVICE_CPU).type());
+ bool non_cpu_dst = (!dst_alloc_attr.on_host() &&
+ dst_device_type != DeviceType(DEVICE_CPU).type());
+ if (non_cpu_src) {
+ gpu::Stream* stream = send_dev_context->stream();
+ if (stream == nullptr) {
+ done(errors::Internal("Failed to find device stream"));
+ return;
+ }
+ auto* src_dev_info = src->tensorflow_gpu_device_info();
+ CHECK(src_dev_info);
+
+ if (non_cpu_dst) {
+ // Device to device copy
+ DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+ stream->ThenMemcpy(
+ &gpu_dst_ptr,
+ DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes},
+ total_bytes);
+ if (dst_device_type == DeviceType(DEVICE_GPU).type()) {
+ // Use of input may outlive stack scope, so keep a ref.
+ Tensor* input_ref = new Tensor(*input);
+ src_dev_info->event_mgr->ThenExecute(
+ stream, [done, stream, input_ref]() {
+ delete input_ref;
+ if (!stream->ok()) {
+ done(errors::Internal("GPU->GPU Memcpy failed"));
+ } else {
+ done(Status::OK());
+ }
+ });
+ }
+ send_dev_context->MaintainLifetimeOnStream(input, stream);
+ } else {
+ // Device to host copy.
+ return send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src,
+ output, done);
+ }
+ } else if (non_cpu_dst) {
+ // Host to Device copy.
+ // Note that this is already an async copy.
+ recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done);
+ } else {
+ memcpy(dst_ptr, src_ptr, total_bytes);
+ done(Status::OK());
+ }
+ } else {
+ // buffer is empty
+ done(Status::OK());
+ }
+}
+
+void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor* gpu_tensor, Tensor* cpu_tensor,
+ StatusCallback done) {
+ VLOG(1) << "CopyGPUTensorToCPU";
+ size_t total_bytes = gpu_tensor->TotalBytes();
+ // Note that 0-size tensors have no backing buffer.
+ if (total_bytes > 0) {
+ const void* src_ptr = DMAHelper::base(gpu_tensor);
+ void* dst_ptr = DMAHelper::base(cpu_tensor);
+ CHECK(dst_ptr);
+ auto* stream = gpu_device->tensorflow_gpu_device_info()->stream;
+ if (device_context) {
+ stream = static_cast<const GPUDeviceContext*>(device_context)->stream();
+ }
+ stream->ThenMemcpy(
+ dst_ptr, DeviceMemoryBase{const_cast<void*>(src_ptr), total_bytes},
+ total_bytes);
+ stream->BlockHostUntilDone();
+ if (!stream->ok()) {
+ done(errors::Internal("CopyGPUTensorToCPU: GPU->CPU Memcpy failed"));
+ return;
+ }
+ }
+
+ done(Status::OK());
+}
+
+/* static */
+void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
+ const DeviceContext* device_context,
+ Device* gpu_device, Tensor* gpu_tensor,
+ StatusCallback done) {
+ VLOG(1) << "CopyCPUTensorToGPU";
+ CHECK(DeviceType(gpu_device->attributes().device_type()) ==
+ DeviceType(DEVICE_GPU));
+
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ done(errors::Internal("Failed to find dest device GPUDeviceInfo"));
+ return;
+ }
+ if (cpu_tensor->TotalBytes() != gpu_tensor->TotalBytes()) {
+ done(errors::Internal(
+ strings::StrCat("Can't copy ", cpu_tensor->TotalBytes(),
+ " bytes of a tensor into another with ",
+ gpu_tensor->TotalBytes(), " bytes buffer.")));
+ return;
+ }
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+ // Note that 0-size tensors have no backing buffer.
+ if (total_bytes > 0) {
+ const void* src_ptr = DMAHelper::base(cpu_tensor);
+ void* dst_ptr = DMAHelper::base(gpu_tensor);
+ DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+
+ CHECK(device_context);
+ auto* stream =
+ static_cast<const GPUDeviceContext*>(device_context)->stream();
+ stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ // Use of cpu_tensor may outlive stack scope, so keep a ref.
+ Tensor* input_ref = new Tensor(*cpu_tensor);
+ dev_info->event_mgr->ThenExecute(stream, [stream, done, input_ref]() {
+ delete input_ref;
+ if (!stream->ok()) {
+ done(errors::Internal("CopyCPUTensorToGPU: GPU Memcpy failed"));
+ } else {
+ done(Status::OK());
+ }
+ });
+ } else {
+ // empty tensor case
+ done(Status::OK());
+ }
+}
+
+Status GPUUtil::Sync(Device* gpu_device) {
+ VLOG(1) << "GPUUtil::Sync";
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ return errors::Internal("Failed to find dest device GPUDeviceInfo");
+ }
+ dev_info->stream->BlockHostUntilDone();
+ if (!dev_info->stream->ok()) {
+ LOG(FATAL) << "GPU sync failed";
+ }
+ return Status::OK();
+}
+
+Status GPUUtil::SyncAll(Device* gpu_device) {
+ VLOG(1) << "GPUUtil::SyncAll";
+ auto* dev_info = gpu_device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ return errors::Internal("Failed to find dest device GPUDeviceInfo");
+ }
+ if (!dev_info->stream->parent()->SynchronizeAllActivity() ||
+ !dev_info->stream->ok()) {
+ LOG(FATAL) << "GPU sync failed";
+ }
+ return Status::OK();
+}
+
+string GPUUtil::MemoryDebugString(const Device* device, Tensor* tensor) {
+ string ret;
+ CHECK(tensor);
+ const int64 num_bytes = std::min<int64>(
+ FLAGS_brain_gpu_util_debug_string_maxlen, tensor->TotalBytes());
+ void* ptr = (num_bytes > 0) ? DMAHelper::base(tensor) : nullptr;
+ strings::Appendf(&ret, "%p:", ptr);
+ if (num_bytes > 0) {
+ auto* dev_info = device->tensorflow_gpu_device_info();
+ if (!dev_info) {
+ strings::StrAppend(
+ &ret, PrintMemory(reinterpret_cast<const char*>(ptr), num_bytes));
+ } else {
+ string buf;
+ buf.resize(num_bytes);
+ DeviceMemoryBase gpu_ptr(ptr, num_bytes);
+ Status s = dev_info->stream->parent()->SynchronousMemcpyD2H(
+ gpu_ptr, num_bytes, gtl::string_as_array(&buf));
+ strings::StrAppend(&ret,
+ PrintMemory(gtl::string_as_array(&buf), num_bytes));
+ }
+ }
+ return ret;
+}
+
+// TODO(pbar) Checksum is called from places without a valid device context.
+uint64 GPUUtil::Checksum(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor& tensor) {
+ Tensor copy(tensor.dtype(), tensor.shape());
+ Status s;
+ Notification n;
+ CopyGPUTensorToCPU(gpu_device, device_context, &tensor, &copy,
+ [&s, &n](Status status) {
+ s.Update(status);
+ n.Notify();
+ });
+ n.WaitForNotification();
+ CHECK(s.ok()) << s;
+ return Checksum(copy);
+}
+
+uint64 GPUUtil::Checksum(const Tensor& tensor) {
+ const float* fptr = reinterpret_cast<const float*>(DMAHelper::base(&tensor));
+ size_t num_bytes = tensor.TotalBytes();
+ size_t num_floats = num_bytes / sizeof(float);
+ for (size_t i = 0; i < num_floats; ++i) {
+ CHECK(!std::isnan(fptr[i])) << " i " << i;
+ }
+ // TODO(tucker): consider using crc32c instead.
+ return Hash64(reinterpret_cast<const char*>(DMAHelper::base(&tensor)),
+ tensor.TotalBytes(), 0);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
new file mode 100644
index 0000000000..1d8c3a054d
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -0,0 +1,89 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/common_runtime/gpu/dma_helper.h"
+#include "tensorflow/stream_executor/device_memory.h"
+
+#include "tensorflow/stream_executor/stream.h"
+
+namespace tensorflow {
+
+class RecvTensorResponse;
+class TensorProto;
+
+namespace gpu = ::perftools::gputools;
+
+class GPUUtil {
+ public:
+ // "tensor" is GPU-local. "dev" is the hosting GPU.
+ // "device_context" should be the context of the GPU "_Send" op
+ // which provides the Tensor.
+  // Sets all necessary fields of "proto" by transferring value
+  // bytes from GPU to CPU RAM. "is_dead" indicates that the
+  // tensor is dead, with an uninitialized value.
+ static void SetProtoFromGPU(const Tensor& tensor, Device* dev,
+ const DeviceContext* device_context,
+ TensorProto* proto, bool is_dead,
+ StatusCallback done);
+
+ // Copies "input" to "output" between devices accessible to the
+ // local process via some DMA-like method. "edge_name" is the name
+ // of the tensor being copied, for debugging purposes. Depending on
+ // the type of devices and memory in use, the copy may be performed
+ // synchronously or asynchronously. 'done' will be invoked only
+ // after the copy is actually complete.
+ static void CopyViaDMA(const string& edge_name,
+ DeviceContext* send_dev_context,
+ DeviceContext* recv_dev_context, Device* src,
+ Device* dst, const AllocatorAttributes src_alloc_attr,
+ const AllocatorAttributes dst_alloc_attr,
+ const Tensor* input, Tensor* output,
+ StatusCallback done);
+
+ // Copies the data in 'gpu_tensor' into 'cpu_tensor'.
+ // 'gpu_tensor''s backing memory must be on 'gpu_device' and
+ // 'cpu_tensor' must be allocated to be of the same size as
+ // 'gpu_tensor'. Synchronous: may block.
+ static void CopyGPUTensorToCPU(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor* gpu_tensor, Tensor* cpu_tensor,
+ StatusCallback done);
+
+ // Blocks until all operations queued on the stream associated with
+ // "gpu_device" at the time of the call have completed. Returns any
+ // error pending on the stream at completion.
+ static Status Sync(Device* gpu_device);
+
+ // Blocks until all operations queued on all streams associated with the
+ // corresponding GPU device at the time of call have completed.
+ // Returns any error pending on the stream at completion.
+ static Status SyncAll(Device* gpu_device);
+
+  // For debugging purposes, given a "device" and a "tensor" allocated
+ // on the device, return a string printing each byte in the tensor
+ // (up to a limit). "device" can be either a CPU or a GPU device.
+ static string MemoryDebugString(const Device* device, Tensor* tensor);
+
+ static perftools::gputools::DeviceMemory<float> AsGPUFloat(const Tensor& t);
+
+ // Computes a checksum over the contents of "tensor", which is allocated
+ // on "gpu_device".
+ static uint64 Checksum(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor& tensor);
+
+ // Computes a checksum over the contents of "tensor", which is allocated
+ // in local CPU RAM.
+ static uint64 Checksum(const Tensor& tensor);
+
+ static void CopyCPUTensorToGPU(const Tensor* cpu_tensor,
+ const DeviceContext* device_context,
+ Device* gpu_device, Tensor* gpu_tensor,
+ StatusCallback done);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
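CopyGPUTensorToCPU and the other copy helpers report completion through the StatusCallback rather than a return value; the Checksum implementation in gpu_util.cc blocks on a Notification to turn that into a synchronous call. A minimal sketch of the same pattern, assuming the caller already holds a valid device, device context, and correctly sized tensors (the helper name BlockingCopyToCPU is illustrative only):

    // Sketch only: turn the callback-based copy into a blocking call.
    #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
    #include "tensorflow/core/lib/core/notification.h"

    tensorflow::Status BlockingCopyToCPU(tensorflow::Device* gpu_device,
                                         const tensorflow::DeviceContext* ctx,
                                         const tensorflow::Tensor* gpu_tensor,
                                         tensorflow::Tensor* cpu_tensor) {
      tensorflow::Status copy_status;
      tensorflow::Notification done;
      tensorflow::GPUUtil::CopyGPUTensorToCPU(
          gpu_device, ctx, gpu_tensor, cpu_tensor,
          [&copy_status, &done](tensorflow::Status s) {
            copy_status.Update(s);  // keep the first error, if any
            done.Notify();
          });
      done.WaitForNotification();  // wait for the copy callback to fire
      return copy_status;
    }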
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
new file mode 100644
index 0000000000..f1b1174a28
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/stream_executor/stream.h"
+
+namespace tensorflow {
+
+void GPUDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
+ Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ GPUUtil::CopyCPUTensorToGPU(cpu_tensor, this, device, device_tensor, done);
+}
+
+void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& tensor_name,
+ Device* device, Tensor* cpu_tensor,
+ StatusCallback done) {
+ GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.cc b/tensorflow/core/common_runtime/gpu/pool_allocator.cc
new file mode 100644
index 0000000000..52deb7fce2
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.cc
@@ -0,0 +1,269 @@
+#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+
+#include <errno.h>
+#include <strings.h>
+#include <sys/mman.h> // for munmap
+
+#include <map>
+
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+//#include "prodkernel/api/base/numa.h"
+
+namespace tensorflow {
+
+PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
+ SubAllocator* allocator,
+ RoundUpInterface* size_rounder, string name)
+ : name_(name),
+ has_size_limit_(pool_size_limit > 0),
+ auto_resize_(auto_resize),
+ pool_size_limit_(pool_size_limit),
+ allocator_(allocator),
+ size_rounder_(size_rounder),
+ allocation_begun_(false) {
+ if (auto_resize) {
+ CHECK_LT(0, pool_size_limit)
+ << "size limit must be > 0 if auto_resize is true.";
+ }
+}
+
+PoolAllocator::~PoolAllocator() { Clear(); }
+
+namespace {
+// Pools contain Chunks allocated from the underlying Allocator.
+// Chunk alignment is always on kPoolAlignment boundaries. Each Chunk
+// begins with a descriptor (ChunkPrefix) that gives its size and a
+// pointer to itself. The pointer returned to the user is just past
+// the ChunkPrefix. If the user asks for a larger alignment, we will
+// increase the size of the chunk, then adjust the returned user
+// pointer and also re-write the ChunkPrefix.chunk_ptr value
+// immediately before it. This way the Chunk address and size can be
+// recovered from the returned user pointer, regardless of alignment.
+// Note that this dereferencing of the pointers means that we cannot
+// handle GPU memory, only CPU memory.
+struct ChunkPrefix {
+ size_t num_bytes;
+ void* chunk_ptr;
+};
+// kPoolAlignment cannot be less than the size of ChunkPrefix.
+static const int kPoolAlignment = sizeof(ChunkPrefix);
+
+void* PrepareChunk(void* chunk, size_t alignment, size_t num_bytes) {
+ ChunkPrefix* cp = reinterpret_cast<ChunkPrefix*>(chunk);
+ cp->num_bytes = num_bytes;
+ cp->chunk_ptr = chunk;
+ void* user_ptr = reinterpret_cast<void*>(cp + 1);
+ if (alignment > kPoolAlignment) {
+ // Move user_ptr forward to the first satisfying offset, and write
+ // chunk_ptr just before it.
+ size_t aligned_ptr = reinterpret_cast<size_t>(user_ptr) + alignment;
+ user_ptr = reinterpret_cast<void*>(aligned_ptr & ~(alignment - 1));
+ (reinterpret_cast<ChunkPrefix*>(user_ptr) - 1)->chunk_ptr = chunk;
+ }
+ // Safety check that user_ptr is always past the ChunkPrefix.
+ CHECK_GE(user_ptr, reinterpret_cast<ChunkPrefix*>(chunk) + 1);
+ return user_ptr;
+}
+
+ChunkPrefix* FindPrefix(void* user_ptr) {
+ ChunkPrefix* cp = reinterpret_cast<ChunkPrefix*>(user_ptr) - 1;
+ return reinterpret_cast<ChunkPrefix*>(cp->chunk_ptr);
+}
+} // namespace
+
+void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ if (!allocation_begun_) allocation_begun_ = true;
+ if (num_bytes == 0) return nullptr;
+
+ // If alignment is larger than kPoolAlignment, increase num_bytes so that we
+ // are guaranteed to be able to return an aligned ptr by advancing user_ptr
+ // without overrunning the end of the chunk.
+ if (alignment > kPoolAlignment) {
+ num_bytes += alignment;
+ }
+ num_bytes += sizeof(ChunkPrefix);
+ num_bytes = size_rounder_->RoundUp(num_bytes);
+ PtrRecord* pr = nullptr;
+ if (has_size_limit_) {
+ {
+ mutex_lock lock(mutex_);
+ auto iter = pool_.find(num_bytes);
+ if (iter == pool_.end()) {
+ allocated_count_++;
+ // Deliberately fall out of lock scope before
+ // calling the allocator. No further modification
+ // to the pool will be performed.
+ } else {
+ get_from_pool_count_++;
+ pr = iter->second;
+ RemoveFromList(pr);
+ pool_.erase(iter);
+        // Fall out of lock scope and reuse the chunk without the lock held.
+ }
+ }
+ }
+ if (pr != nullptr) {
+ void* r = pr->ptr;
+ delete pr;
+ return PrepareChunk(r, alignment, num_bytes);
+ } else {
+ void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
+ for (auto v : alloc_visitors_) {
+ v(ptr, num_bytes);
+ }
+ return PrepareChunk(ptr, alignment, num_bytes);
+ }
+}
+
+void PoolAllocator::DeallocateRaw(void* ptr) {
+ if (ptr == nullptr) return;
+ ChunkPrefix* cp = FindPrefix(ptr);
+ CHECK_LE((void*)cp, (void*)ptr);
+ if (!has_size_limit_ && !auto_resize_) {
+ for (auto v : free_visitors_) {
+ v(cp, cp->num_bytes);
+ }
+ allocator_->Free(cp, cp->num_bytes);
+ } else {
+ mutex_lock lock(mutex_);
+ ++put_count_;
+ while (pool_.size() >= pool_size_limit_) {
+ EvictOne();
+ }
+ PtrRecord* pr = new PtrRecord;
+ pr->num_bytes = cp->num_bytes;
+ pr->ptr = cp;
+ AddToList(pr);
+ pool_.insert(std::make_pair(cp->num_bytes, pr));
+ }
+}
+
+void PoolAllocator::Clear() {
+ if (has_size_limit_) {
+ mutex_lock lock(mutex_);
+ for (auto iter : pool_) {
+ PtrRecord* pr = iter.second;
+ for (auto v : free_visitors_) {
+ v(pr->ptr, pr->num_bytes);
+ }
+ allocator_->Free(pr->ptr, pr->num_bytes);
+ delete pr;
+ }
+ pool_.clear();
+ get_from_pool_count_ = 0;
+ put_count_ = 0;
+ allocated_count_ = 0;
+ evicted_count_ = 0;
+ lru_head_ = nullptr;
+ lru_tail_ = nullptr;
+ }
+}
+
+void PoolAllocator::RemoveFromList(PtrRecord* pr) {
+ if (pr->prev == nullptr) {
+ DCHECK_EQ(lru_head_, pr);
+ lru_head_ = nullptr;
+ } else {
+ pr->prev->next = pr->next;
+ }
+ if (pr->next == nullptr) {
+ DCHECK_EQ(lru_tail_, pr);
+ lru_tail_ = pr->prev;
+ } else {
+ pr->next->prev = pr->prev;
+ if (lru_head_ == nullptr) {
+ lru_head_ = pr->next;
+ }
+ }
+}
+
+void PoolAllocator::AddToList(PtrRecord* pr) {
+ pr->prev = nullptr;
+ if (lru_head_ == nullptr) {
+ CHECK(lru_tail_ == nullptr);
+ lru_tail_ = pr;
+ pr->next = nullptr;
+ } else {
+ pr->next = lru_head_;
+ pr->next->prev = pr;
+ }
+ lru_head_ = pr;
+}
+
+void PoolAllocator::EvictOne() {
+ DCHECK(lru_tail_ != nullptr);
+ PtrRecord* prec = lru_tail_;
+ RemoveFromList(prec);
+ auto iter = pool_.find(prec->num_bytes);
+ while (iter->second != prec) {
+ ++iter;
+ DCHECK(iter != pool_.end());
+ }
+ pool_.erase(iter);
+ for (auto v : free_visitors_) {
+ v(prec->ptr, prec->num_bytes);
+ }
+ allocator_->Free(prec->ptr, prec->num_bytes);
+ delete prec;
+ ++evicted_count_;
+ // Auto-resizing, and warning messages.
+ static const double kTolerable = 2e-3;
+ static const int kCheckInterval = 1000;
+ static const double kIncreaseFactor = 1.1;
+ static const int kMinPoolSize = 100;
+ if (0 == evicted_count_ % kCheckInterval) {
+ const double eviction_rate =
+ evicted_count_ / static_cast<double>(put_count_);
+ const int64 alloc_request_count = allocated_count_ + get_from_pool_count_;
+ const double alloc_rate =
+ allocated_count_ / static_cast<double>(alloc_request_count);
+ static int log_counter = 0;
+ // (counter increment not thread safe but it's just for logging, so we
+ // don't care).
+ bool should_log = ((log_counter++ % 10) == 0);
+ if (should_log) {
+ LOG(WARNING) << "PoolAllocator: After " << alloc_request_count
+ << " get requests, put_count=" << put_count_
+ << " evicted_count=" << evicted_count_
+ << " eviction_rate=" << eviction_rate
+ << " and unsatisfied allocation rate=" << alloc_rate;
+ }
+ if (auto_resize_ && (eviction_rate > kTolerable) &&
+ (alloc_rate > kTolerable)) {
+ size_t new_size_limit = (pool_size_limit_ < kMinPoolSize)
+ ? kMinPoolSize
+ : (kIncreaseFactor * pool_size_limit_);
+ if (should_log) {
+ LOG(INFO) << "Raising pool_size_limit_ from " << pool_size_limit_
+ << " to " << new_size_limit;
+ }
+ pool_size_limit_ = new_size_limit;
+ // Reset all the counters so that ratios are relative to new sizes
+ // at next test interval.
+ put_count_ = 0;
+ allocated_count_ = 0;
+ evicted_count_ = 0;
+ get_from_pool_count_ = 0;
+ }
+ }
+}
+
+void PoolAllocator::AddAllocVisitor(Visitor visitor) {
+ mutex_lock lock(mutex_);
+ CHECK(!allocation_begun_)
+ << "AddAllocVisitor may not be called after pool allocation "
+ << "has begun.";
+ alloc_visitors_.push_back(visitor);
+}
+
+void PoolAllocator::AddFreeVisitor(Visitor visitor) {
+ mutex_lock lock(mutex_);
+ CHECK(!allocation_begun_)
+ << "AddFreeVisitor may not be called after pool allocation "
+ << "has begun.";
+ free_visitors_.push_back(visitor);
+}
+
+} // namespace tensorflow
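The ChunkPrefix comment above describes the alignment trick in PrepareChunk: when the requested alignment exceeds kPoolAlignment, the user pointer is advanced by the alignment and masked down, which always lands on an aligned address that still leaves room for a prefix copy immediately before it (AllocateRaw pads num_bytes by the alignment to cover the shift). A small standalone check of that arithmetic, independent of the allocator itself; the addresses and alignments below are arbitrary examples:

    // Standalone illustration of the PrepareChunk() pointer math; no
    // TensorFlow dependencies.
    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    int main() {
      const uintptr_t kPrefixSize = 2 * sizeof(void*);  // mimics sizeof(ChunkPrefix)
      // Chunk starts are kPoolAlignment-aligned, as AllocateRaw guarantees.
      for (uintptr_t chunk = 0x1000; chunk < 0x1100; chunk += 16) {
        for (uintptr_t alignment = 32; alignment <= 256; alignment *= 2) {
          uintptr_t user = chunk + kPrefixSize;              // just past prefix
          uintptr_t aligned = (user + alignment) & ~(alignment - 1);
          assert(aligned % alignment == 0);                  // correctly aligned
          assert(aligned >= chunk + kPrefixSize);            // prefix still fits
          assert(aligned - chunk <= kPrefixSize + alignment);  // within padding
        }
      }
      return 0;
    }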
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h
new file mode 100644
index 0000000000..d10aabe88a
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h
@@ -0,0 +1,202 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
+
+// Simple LRU pool allocators for various flavors of CPU RAM that
+// implement the VisitableAllocator interface. GPU memory is managed
+// by GPURegionAllocator.
+
+#include <atomic>
+#include <map>
+#include <memory>
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// Interface of an object that does the underlying alloc/free of memory.
+class SubAllocator {
+ public:
+ virtual ~SubAllocator() {}
+ virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
+ virtual void Free(void* ptr, size_t num_bytes) = 0;
+};
+
+// Interface of an object that rounds up integers.
+class RoundUpInterface {
+ public:
+ virtual ~RoundUpInterface() {}
+ virtual size_t RoundUp(size_t num_bytes) = 0;
+};
+
+// Size-limited pool of memory buffers obtained from a SubAllocator
+// instance. Pool eviction policy is LRU.
+class PoolAllocator : public VisitableAllocator {
+ public:
+ // "pool_size_limit" is the maximum number of returned, re-usable
+ // memory buffers to keep in the pool. If pool_size_limit == 0, the
+ // pool is effectively a thin wrapper around the allocator.
+ // If "auto_resize" is true, then the pool_size_limit will gradually
+ // be raised so that deallocations happen very rarely, if at all.
+ // Transitory start-up objects may deallocate, but the long-term
+ // working-set should not. Auto-resizing can raise pool_size_limit
+ // but will never lower it.
+ // "allocator" is the object that performs the underlying memory
+ // malloc/free operations. This object takes ownership of allocator.
+ PoolAllocator(size_t pool_size_limit, bool auto_resize,
+ SubAllocator* allocator, RoundUpInterface* size_rounder,
+ string name);
+ ~PoolAllocator() override;
+
+ string Name() override { return name_; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+
+ void DeallocateRaw(void* ptr) override;
+
+ // REQUIRES: The following functions may only be called prior
+ // to the first Allocate*() call. Once allocation has begun, it is
+ // illegal to register another visitor.
+
+ void AddAllocVisitor(Visitor visitor) override;
+
+ void AddFreeVisitor(Visitor visitor) override;
+
+ // Allocate an unused memory region of size "num_bytes". Fetch from
+ // the pool if available, otherwise call allocator_.
+ void* Get(size_t num_bytes);
+
+ // Return a no-longer needed memory region to the pool. It is an error
+  // to dereference "ptr" after this call. If the pool is full, the least
+ // recently used region will be deallocated.
+ void Put(void* ptr, size_t num_bytes);
+
+ // Reset the pool to empty.
+ void Clear();
+
+ // The following accessors permit monitoring the effectiveness of
+ // the pool at avoiding repeated malloc/frees on the underlying
+ // allocator. Read locks are not taken on the theory that value
+ // consistency with other threads is not important.
+
+ // Number of Get() requests satisfied from pool.
+ int64 get_from_pool_count() const NO_THREAD_SAFETY_ANALYSIS {
+ return get_from_pool_count_;
+ }
+ // Number of Put() requests.
+ int64 put_count() const NO_THREAD_SAFETY_ANALYSIS { return put_count_; }
+ // Number of Get() requests requiring a fresh allocation.
+ int64 allocated_count() const NO_THREAD_SAFETY_ANALYSIS {
+ return allocated_count_;
+ }
+ // Number of pool evictions.
+ int64 evicted_count() const NO_THREAD_SAFETY_ANALYSIS {
+ return evicted_count_;
+ }
+ // Current size limit.
+ size_t size_limit() const NO_THREAD_SAFETY_ANALYSIS {
+ return pool_size_limit_;
+ }
+
+ private:
+ struct PtrRecord {
+ void* ptr;
+ size_t num_bytes;
+ PtrRecord* prev;
+ PtrRecord* next;
+ };
+
+ // Remove "pr" from the double-linked LRU list.
+ void RemoveFromList(PtrRecord* pr) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Add "pr" to the head of the double-linked LRU list.
+ void AddToList(PtrRecord* pr) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Delete the least recently used record.
+ void EvictOne() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ const string name_;
+ const bool has_size_limit_;
+ const bool auto_resize_;
+ size_t pool_size_limit_;
+ std::unique_ptr<SubAllocator> allocator_;
+ std::unique_ptr<RoundUpInterface> size_rounder_;
+ mutex mutex_;
+ std::multimap<const size_t, PtrRecord*> pool_ GUARDED_BY(mutex_);
+ PtrRecord* lru_head_ GUARDED_BY(mutex_) = nullptr;
+ PtrRecord* lru_tail_ GUARDED_BY(mutex_) = nullptr;
+ int64 get_from_pool_count_ GUARDED_BY(mutex_) = 0;
+ int64 put_count_ GUARDED_BY(mutex_) = 0;
+ int64 allocated_count_ GUARDED_BY(mutex_) = 0;
+ int64 evicted_count_ GUARDED_BY(mutex_) = 0;
+ // Write access to these is guarded by mutex_, but not read
+ // access. They may only be modified prior to the first
+ // allocation. Later attempts to modify will fail.
+ std::vector<Visitor> alloc_visitors_;
+ std::vector<Visitor> free_visitors_;
+ std::atomic<bool> allocation_begun_;
+};
+
+// Do-nothing rounder. Passes through sizes unchanged.
+class NoopRounder : public RoundUpInterface {
+ public:
+ size_t RoundUp(size_t num_bytes) override { return num_bytes; }
+};
+
+// Power of 2 rounder: rounds up to nearest power of 2 size.
+class Pow2Rounder : public RoundUpInterface {
+ public:
+ size_t RoundUp(size_t num_bytes) override {
+ return 1uLL << Log2Ceiling64(num_bytes);
+ }
+};
+
+class BasicCPUAllocator : public SubAllocator {
+ public:
+ ~BasicCPUAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ return port::aligned_malloc(num_bytes, alignment);
+ }
+ void Free(void* ptr, size_t num_bytes) override { free(ptr); }
+};
+
+// Allocator for pinned CPU RAM that is made known to CUDA for the
+// purpose of efficient DMA with a GPU.
+class CUDAHostAllocator : public SubAllocator {
+ public:
+ // Note: stream_exec cannot be null.
+ explicit CUDAHostAllocator(perftools::gputools::StreamExecutor* stream_exec)
+ : stream_exec_(stream_exec) {
+ CHECK(stream_exec_ != nullptr);
+ }
+ ~CUDAHostAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = stream_exec_->HostMemoryAllocate(num_bytes);
+ if (ptr == nullptr) {
+ LOG(FATAL) << "could not allocate pinned host memory of size: "
+ << num_bytes;
+ }
+ }
+ return ptr;
+ }
+
+ void Free(void* ptr, size_t num_bytes) override {
+ if (ptr != nullptr) {
+ stream_exec_->HostMemoryDeallocate(ptr);
+ }
+ }
+
+ private:
+ perftools::gputools::StreamExecutor* stream_exec_; // not owned, non-null
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
new file mode 100644
index 0000000000..ca409b2b4c
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -0,0 +1,203 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
+#include <gtest/gtest.h>
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+namespace {
+
+TEST(PoolAllocatorTest, ZeroSizeBuffers) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 2 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+
+ EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
+ pool.DeallocateRaw(nullptr); // Should not crash.
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+}
+
+TEST(PoolAllocatorTest, ZeroSizePool) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 0 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+
+ // All allocations should bypass the pool and return valid pointers.
+ for (int i = 0; i < 3; ++i) {
+ void* p0 = pool.AllocateRaw(4, 0);
+ void* p4 = pool.AllocateRaw(4, 4);
+ void* p12 = pool.AllocateRaw(4, 12);
+ EXPECT_EQ(nullptr, p0);
+ EXPECT_NE(nullptr, p4);
+ EXPECT_NE(nullptr, p12);
+ pool.DeallocateRaw(p0);
+ pool.DeallocateRaw(p4);
+ pool.DeallocateRaw(p12);
+ }
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+}
+
+TEST(PoolAllocatorTest, Alignment) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 0 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+ for (int i = 0; i < 16; ++i) {
+ size_t alignment = 1 << i;
+ void* p = pool.AllocateRaw(alignment, 111);
+ EXPECT_TRUE(p != nullptr);
+ EXPECT_EQ(0, reinterpret_cast<int64>(p) & (alignment - 1))
+ << "ptr: " << p << " alignment " << alignment;
+ // Intentionally don't deallocate, to test that destruction of
+ // the PoolAllocator frees all pending memory.
+ }
+}
+
+TEST(PoolAllocatorTest, AutoResize) {
+ PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
+ new BasicCPUAllocator, new NoopRounder, "pool");
+
+ // Alloc/dealloc 10 sizes just a few times, confirming pool size
+ // stays at 2.
+ for (int i = 0; i < 10; ++i) {
+ void* p = pool.AllocateRaw(4, 64 << i);
+ pool.DeallocateRaw(p);
+ }
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(10, pool.allocated_count());
+ EXPECT_EQ(10, pool.put_count());
+ EXPECT_EQ(8, pool.evicted_count());
+ EXPECT_EQ(2, pool.size_limit());
+
+ // Then repeat 1200 times. Pool size limit should jump to 100.
+ for (int j = 0; j < 120; ++j) {
+ for (int i = 0; i < 10; ++i) {
+ void* p = pool.AllocateRaw(4, 64 << i);
+ pool.DeallocateRaw(p);
+ }
+ }
+ EXPECT_EQ(100, pool.size_limit());
+}
+
+TEST(PoolAllocatorTest, CudaHostAllocator) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 2 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+
+ // Repeatedly Get a 16-byte value, confirming that there's only
+ // one real allocation.
+ void* p1_16 = pool.AllocateRaw(4, 16);
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(1, pool.allocated_count());
+ EXPECT_NE(nullptr, p1_16);
+ pool.DeallocateRaw(p1_16);
+ // Pool contents {16}
+ EXPECT_EQ(1, pool.put_count());
+ void* p2_16 = pool.AllocateRaw(4, 16); // Get it again.
+ EXPECT_EQ(1, pool.get_from_pool_count());
+ EXPECT_EQ(1, pool.allocated_count());
+ EXPECT_EQ(p1_16, p2_16); // Same pointer value
+ pool.DeallocateRaw(p2_16); // Put it back.
+ // Pool contents {16}
+ EXPECT_EQ(2, pool.put_count());
+
+ // Get two more values of different sizes.
+ void* p3_4 = pool.AllocateRaw(4, 4);
+ EXPECT_EQ(2, pool.allocated_count());
+ EXPECT_NE(p1_16, p3_4); // Different pointer value
+ EXPECT_NE(nullptr, p3_4);
+ pool.DeallocateRaw(p3_4); // Put it back. Pool is now full.
+ // Pool contents {4, 16}
+ EXPECT_EQ(3, pool.put_count());
+ void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer.
+ EXPECT_NE(nullptr, p4_2);
+ EXPECT_EQ(0, pool.evicted_count());
+
+ // The pool is full: when we put back p4_2, the 16-byte buffer
+ // should be evicted since it was least recently inserted.
+ pool.DeallocateRaw(p4_2);
+ // Pool contents {2, 4}
+ EXPECT_EQ(4, pool.put_count());
+ EXPECT_EQ(1, pool.evicted_count());
+
+ // Re-getting and putting size 2 or 4 should not alter pool size or
+ // num-evicted.
+ void* p5_4 = pool.AllocateRaw(4, 4);
+ EXPECT_NE(nullptr, p5_4);
+ pool.DeallocateRaw(p5_4);
+ void* p6_2 = pool.AllocateRaw(4, 2);
+ EXPECT_NE(nullptr, p6_2);
+ pool.DeallocateRaw(p6_2);
+ EXPECT_EQ(3, pool.get_from_pool_count());
+ EXPECT_EQ(6, pool.put_count());
+ EXPECT_EQ(3, pool.allocated_count());
+ EXPECT_EQ(1, pool.evicted_count());
+
+ pool.Clear();
+ EXPECT_EQ(0, pool.get_from_pool_count());
+ EXPECT_EQ(0, pool.put_count());
+ EXPECT_EQ(0, pool.allocated_count());
+ EXPECT_EQ(0, pool.evicted_count());
+}
+
+TEST(PoolAllocatorTest, Pow2Rounder) {
+ Pow2Rounder rounder;
+ EXPECT_EQ(1, rounder.RoundUp(1));
+ EXPECT_EQ(2, rounder.RoundUp(2));
+ EXPECT_EQ(16, rounder.RoundUp(9));
+ EXPECT_EQ(16, rounder.RoundUp(16));
+ EXPECT_EQ(65536, rounder.RoundUp(41234));
+ EXPECT_EQ(65536, rounder.RoundUp(65535));
+ EXPECT_EQ(65536, rounder.RoundUp(65536));
+}
+
+TEST(PoolAllocatorTest, Name) {
+ gpu::Platform* platform =
+ gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ PoolAllocator pool(
+ 2 /*pool_size_limit*/, false /*auto_resize*/,
+ new CUDAHostAllocator(
+ platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie()),
+ new NoopRounder, "pool");
+ EXPECT_EQ("pool", pool.Name());
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
new file mode 100644
index 0000000000..70ac6130c2
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -0,0 +1,220 @@
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_region_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+#if defined(PLATFORM_GOOGLE)
+DEFINE_bool(record_mem_types, false,
+ "If true, record attributes of memory allocations and "
+            "dynamically check for appropriate use of registered memory. "
+ "Should only be true for debugging or diagnosis of "
+ "performance issues.");
+DEFINE_bool(brain_mem_reg_cuda_dma, true,
+ "If true, register CPU RAM used to copy to/from GPU RAM "
+ "with the CUDA driver.");
+DEFINE_bool(brain_gpu_use_bfc_allocator, false,
+ "If true, uses the Best-Fit GPU allocator.");
+DEFINE_bool(brain_gpu_region_allocator_debug, false,
+ "If true, checks for memory overwrites by writing "
+ "distinctive patterns on both ends of allocated memory.");
+DEFINE_bool(brain_gpu_region_allocator_reset_to_nan, false,
+ "If true, initializes all new Malloc buffers to NaN, "
+ "and resets the buffer to NaN upon Free.");
+
+#else
+bool FLAGS_record_mem_types = false;
+bool FLAGS_brain_mem_reg_cuda_dma = true;
+bool FLAGS_brain_gpu_region_allocator_debug = false;
+bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false;
+bool FLAGS_brain_gpu_use_bfc_allocator = false;
+#endif
+
+namespace gpu = ::perftools::gputools;
+
+namespace tensorflow {
+
+ProcessState* ProcessState::instance_ = nullptr;
+
+/*static*/ ProcessState* ProcessState::singleton() {
+ if (instance_ == nullptr) {
+ instance_ = new ProcessState;
+ }
+
+ return instance_;
+}
+
+ProcessState::ProcessState() : gpu_count_(0) {
+ CHECK(instance_ == nullptr);
+ instance_ = this;
+}
+
+ProcessState::~ProcessState() {
+ for (auto p : gpu_allocators_) {
+ delete p;
+ }
+ instance_ = nullptr;
+}
+
+string ProcessState::MemDesc::DebugString() {
+ return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index, ", dma: ",
+ gpu_registered, ", nic: ", nic_registered);
+}
+
+ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
+ if (FLAGS_record_mem_types) {
+ auto iter = mem_desc_map_.find(ptr);
+ if (iter != mem_desc_map_.end()) {
+ return iter->second;
+ }
+ }
+ return MemDesc();
+}
+
+void ProcessState::SetGPUCount(int c) {
+ CHECK(gpu_count_ == 0 || gpu_count_ == c)
+ << "Cannot call SetGPUCount with a non-zero value "
+ << "not equal to prior set value.";
+ gpu_count_ = c;
+}
+
+int ProcessState::GPUCount() const { return gpu_count_; }
+
+Allocator* ProcessState::GetGPUAllocator(int gpu_id, size_t total_bytes) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ gpu::Platform* gpu_platform = GPUMachineManager();
+
+ // Verify that gpu_id is legitimate.
+ CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount())
+ << "gpu_id is outside discovered device range";
+
+ if (gpu_id >= static_cast<int64>(gpu_allocators_.size())) {
+ gpu_allocators_.resize(gpu_id + 1);
+ if (FLAGS_record_mem_types) gpu_al_.resize(gpu_id + 1);
+ }
+
+ if (gpu_allocators_[gpu_id] == nullptr) {
+ VisitableAllocator* gpu_allocator;
+
+ if (FLAGS_brain_gpu_use_bfc_allocator) {
+ gpu_allocator = new GPUBFCAllocator(gpu_id, total_bytes);
+ } else {
+ gpu_allocator = new GPURegionAllocator(gpu_id, total_bytes);
+ }
+
+ if (FLAGS_brain_gpu_region_allocator_debug) {
+ gpu_allocator = new GPUDebugAllocator(gpu_allocator, gpu_id);
+ }
+ if (FLAGS_brain_gpu_region_allocator_reset_to_nan) {
+ gpu_allocator = new GPUNanResetAllocator(gpu_allocator, gpu_id);
+ }
+
+ gpu_allocators_[gpu_id] = gpu_allocator;
+
+ // If there are any pending AllocVisitors for this bus, add
+ // them now.
+ gpu::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ int bus_id = se->GetDeviceDescription().numa_node();
+ if (bus_id < static_cast<int64>(gpu_visitors_.size())) {
+ for (auto v : gpu_visitors_[bus_id]) {
+ gpu_allocators_[gpu_id]->AddAllocVisitor(v);
+ }
+ }
+ if (FLAGS_record_mem_types) {
+ MemDesc md;
+ md.loc = MemDesc::GPU;
+ md.dev_index = gpu_id;
+ md.gpu_registered = false;
+ md.nic_registered = true;
+ if (static_cast<int64>(gpu_al_.size()) <= gpu_id)
+ gpu_al_.resize(gpu_id + 1);
+ gpu_al_[gpu_id] = new internal::RecordingAllocator(
+ &mem_desc_map_, gpu_allocators_[gpu_id], md, &mu_);
+ }
+ }
+ if (FLAGS_record_mem_types) return gpu_al_[gpu_id];
+ return gpu_allocators_[gpu_id];
+#else
+ LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
+ return nullptr;
+#endif // GOOGLE_CUDA
+}
+
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
+ // Although we're temporarily ignoring numa_node, check for legality.
+ CHECK_GE(numa_node, 0);
+ // TODO(tucker): actually maintain separate CPUAllocators for
+ // different numa_nodes. For now, just one.
+ numa_node = 0;
+ mutex_lock lock(mu_);
+ while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+ cpu_allocators_.push_back(new PoolAllocator(
+ 100 /*pool_size_limit*/, true /*auto_resize*/, new BasicCPUAllocator(),
+ new NoopRounder, "cpu_pool"));
+ }
+ return cpu_allocators_[0];
+}
+
+Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
+ if (gpu_count_ == 0 || !FLAGS_brain_mem_reg_cuda_dma) {
+ return GetCPUAllocator(numa_node);
+ }
+ // Although we're temporarily ignoring numa_node, check for legality.
+ CHECK_GE(numa_node, 0);
+ // TODO(tucker): actually maintain separate CPUAllocators for
+ // different numa_nodes. For now, just one.
+ numa_node = 0;
+ mutex_lock lock(mu_);
+ while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+    // CUDAHost allocation is the same across all GPUs, so just get the
+    // executor for the first device.
+ gpu::Platform* gpu_platform = GPUMachineManager();
+ gpu::StreamExecutor* se = gpu_platform->ExecutorForDevice(0).ValueOrDie();
+ CHECK(se);
+ cuda_host_allocators_.push_back(new PoolAllocator(
+ 100 /*pool_size_limit*/, true /*auto_resize*/,
+ new CUDAHostAllocator(se), new Pow2Rounder, "cuda_host"));
+ if (FLAGS_record_mem_types) {
+ MemDesc md;
+ md.loc = MemDesc::CPU;
+ md.dev_index = 0;
+ md.gpu_registered = true;
+ md.nic_registered = false;
+ cuda_al_.push_back(new internal::RecordingAllocator(
+ &mem_desc_map_, cuda_host_allocators_.back(), md, &mu_));
+ }
+ }
+ if (FLAGS_record_mem_types) return cuda_al_[0];
+ return cuda_host_allocators_[0];
+}
+
+void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ gpu::Platform* gpu_platform = GPUMachineManager();
+ for (int gpu_id = 0; gpu_id < static_cast<int64>(gpu_allocators_.size());
+ ++gpu_id) {
+ gpu::StreamExecutor* se =
+ gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ if (gpu_allocators_[gpu_id] &&
+ se->GetDeviceDescription().numa_node() == bus_id) {
+ gpu_allocators_[gpu_id]->AddAllocVisitor(visitor);
+ }
+ }
+ while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
+ gpu_visitors_.push_back(std::vector<AllocVisitor>());
+ }
+ gpu_visitors_[bus_id].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h
new file mode 100644
index 0000000000..527d12c10d
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/process_state.h
@@ -0,0 +1,140 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
+
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class Allocator;
+class VisitableAllocator;
+class PoolAllocator;
+
+// Singleton that manages per-process state, e.g. allocation
+// of shared resources.
+class ProcessState {
+ public:
+ static ProcessState* singleton();
+
+ // Descriptor for memory allocation attributes, used by optional
+ // runtime correctness analysis logic.
+ struct MemDesc {
+ enum MemLoc { CPU, GPU };
+ MemLoc loc;
+ int dev_index;
+ bool gpu_registered;
+ bool nic_registered;
+ MemDesc()
+ : loc(CPU),
+ dev_index(0),
+ gpu_registered(false),
+ nic_registered(false) {}
+ string DebugString();
+ };
+
+ // Records the number of GPUs available in the local process.
+  // It is a fatal error to call this with a value different from the
+  // value passed in a prior call.
+ void SetGPUCount(int c);
+
+  // Returns the number of GPUs available in the local process, as set by
+  // SetGPUCount(). Returns 0 if SetGPUCount has not been called.
+ int GPUCount() const;
+
+ // Returns what we know about the memory at ptr.
+  // If we know nothing, it is reported as CPU 0 with no other attributes.
+ MemDesc PtrType(const void* ptr);
+
+ // Returns the one CPUAllocator used for the given numa_node.
+  // TEMPORARY: ignores numa_node.
+ Allocator* GetCPUAllocator(int numa_node);
+
+ // Returns the one GPU allocator used for the indexed GPU.
+ // Note that this is a system GPU index, not (necessarily) a brain
+ // device index.
+ //
+ // 'total_bytes' is the total number of bytes that should be made
+ // available to the allocator. The first call to this function for
+ // a given gpu_id creates the allocator, so only the total_bytes
+ // used on that first call is used.
+ //
+ // REQUIRES: gpu_id must be a valid ordinal for a GPU available in the
+ // current system environment. Otherwise returns nullptr.
+ Allocator* GetGPUAllocator(int gpu_id, size_t total_bytes);
+
+ Allocator* GetCUDAHostAllocator(int numa_node);
+
+ // Registers a function to be called once on every new Region
+ // allocated by every GPURegionAllocator proximate to the specified
+ // bus. The AllocVisitor is provided with a memory pointer and the
+ // size of the area it identifies. The pointer is not guaranteed to
+ // be valid after the call terminates. The intention is for this
+ // interface to be used for network device memory registration.
+ // "bus_id" is platform-specific. On many platforms it
+ // should be 0. On machines with multiple PCIe buses, it should be
+  // the index of one of the PCIe buses. If the bus_id is invalid,
+ // results are undefined.
+ typedef std::function<void(void*, size_t)> AllocVisitor;
+ void AddGPUAllocVisitor(int bus_id, AllocVisitor visitor);
+
+ typedef std::unordered_map<const void*, MemDesc> MDMap;
+
+ protected:
+ ProcessState();
+
+ static ProcessState* instance_;
+
+ mutex mu_;
+ int gpu_count_;
+
+ std::vector<PoolAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
+ std::vector<PoolAllocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+
+ virtual ~ProcessState();
+
+ // Optional RecordingAllocators that wrap the corresponding
+ // Allocators for runtime attribute use analysis.
+ MDMap mem_desc_map_;
+ std::vector<Allocator*> cpu_al_ GUARDED_BY(mu_);
+ std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
+};
+
+namespace internal {
+class RecordingAllocator : public Allocator {
+ public:
+ RecordingAllocator(ProcessState::MDMap* mm, Allocator* a,
+ ProcessState::MemDesc md, mutex* mu)
+ : mm_(mm), a_(a), md_(md), mu_(mu) {}
+
+ string Name() override { return a_->Name(); }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* p = a_->AllocateRaw(alignment, num_bytes);
+ mutex_lock l(*mu_);
+ (*mm_)[p] = md_;
+ return p;
+ }
+ void DeallocateRaw(void* p) override {
+ mutex_lock l(*mu_);
+ auto iter = mm_->find(p);
+ mm_->erase(iter);
+ a_->DeallocateRaw(p);
+ }
+ bool TracksAllocationSizes() override { return a_->TracksAllocationSizes(); }
+ size_t RequestedSize(void* p) override { return a_->RequestedSize(p); }
+ size_t AllocatedSize(void* p) override { return a_->AllocatedSize(p); }
+ ProcessState::MDMap* mm_; // not owned
+ Allocator* a_; // not owned
+ ProcessState::MemDesc md_;
+ mutex* mu_;
+};
+} // namespace internal
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
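The allocator accessors above follow the pattern used elsewhere in this change; for instance, SetProtoFromGPU in gpu_util.cc fetches the pinned-RAM allocator through the singleton. A minimal sketch of that access pattern, assuming GPU devices have already been initialized in the process (the helper name AllocatePinnedStagingBuffer is illustrative only):

    // Sketch only: obtain a staging buffer in CUDA-registered ("pinned")
    // CPU RAM suitable for DMA with a GPU. GetCUDAHostAllocator falls back
    // to the ordinary CPU allocator when no GPUs are present.
    #include "tensorflow/core/common_runtime/gpu/process_state.h"

    char* AllocatePinnedStagingBuffer(size_t num_bytes) {
      tensorflow::ProcessState* ps = tensorflow::ProcessState::singleton();
      tensorflow::Allocator* alloc = ps->GetCUDAHostAllocator(0 /*numa_node*/);
      char* staging = alloc->Allocate<char>(num_bytes);
      // Callers must later release it via alloc->Deallocate<char>(staging).
      return staging;
    }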
diff --git a/tensorflow/core/common_runtime/gpu/visitable_allocator.h b/tensorflow/core/common_runtime/gpu/visitable_allocator.h
new file mode 100644
index 0000000000..23feed9aab
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/visitable_allocator.h
@@ -0,0 +1,30 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
+
+#include <functional>
+#include "tensorflow/core/framework/allocator.h"
+
+namespace tensorflow {
+
+// Subclass VisitableAllocator instead of Allocator when a memory
+// allocator needs to enable some kind of registration/deregistration
+// of memory areas.
+class VisitableAllocator : public Allocator {
+ public:
+ // Visitor gets called with a pointer to a memory area and its
+ // size in bytes.
+ typedef std::function<void(void*, size_t)> Visitor;
+
+ // Register a visitor guaranteed to be called exactly once on each
+ // chunk of memory newly allocated from the underlying device.
+ // Typically, chunks will be reused and possibly sub-divided by a
+ // pool manager, so the calls will happen only once per process
+ // execution, not once per tensor (re)allocation.
+ virtual void AddAllocVisitor(Visitor visitor) = 0;
+
+ // Register a visitor guaranteed to be called on each chunk of
+ // memory returned to the underlying device.
+ virtual void AddFreeVisitor(Visitor visitor) = 0;
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_VISITABLE_ALLOCATOR_H_
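Since a Visitor is just a std::function<void(void*, size_t)>, registration can be done with a lambda; note that PoolAllocator requires visitors to be added before the first allocation. A minimal, hypothetical sketch (the logging body stands in for the network-device registration this hook is intended for):

    // Sketch only: log each newly allocated region. In practice this hook is
    // intended for network-device (e.g. RDMA) memory registration.
    #include "tensorflow/core/common_runtime/gpu/visitable_allocator.h"
    #include "tensorflow/core/platform/logging.h"

    void LogNewRegions(tensorflow::VisitableAllocator* allocator) {
      allocator->AddAllocVisitor([](void* ptr, size_t num_bytes) {
        LOG(INFO) << "new region at " << ptr << " (" << num_bytes << " bytes)";
      });
    }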
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
new file mode 100644
index 0000000000..03fd9a97c3
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -0,0 +1,45 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_base.h"
+
+namespace perftools {
+namespace gputools {
+class Stream;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+namespace gpu = ::perftools::gputools;
+
+class GPUDeviceContext : public DeviceContext {
+ public:
+ GPUDeviceContext(int stream_id, gpu::Stream* stream)
+ : stream_id_(stream_id), stream_(stream) {}
+
+ ~GPUDeviceContext() override {}
+
+ gpu::Stream* stream() const override { return stream_; }
+ int stream_id() const { return stream_id_; }
+
+ void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const override;
+
+ void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& edge_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) override;
+
+ void MaintainLifetimeOnStream(
+ const Tensor* t, perftools::gputools::Stream* stream) const override {}
+
+ private:
+ int stream_id_;
+ gpu::Stream* stream_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
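
As a usage illustration for the asynchronous copy interface above, the sketch below wraps CopyCPUTensorToDevice() in a blocking helper by waiting on a Notification. The helper name and the enclosing namespace are invented for the example; only the call signature comes from the class declaration.

    #include "tensorflow/core/common_runtime/gpu_device_context.h"
    #include "tensorflow/core/lib/core/notification.h"
    #include "tensorflow/core/public/status.h"
    #include "tensorflow/core/public/tensor.h"

    namespace example {

    // Issues the asynchronous copy and blocks until its callback fires.
    tensorflow::Status CopyToDeviceBlocking(tensorflow::GPUDeviceContext* ctx,
                                            tensorflow::Device* gpu_device,
                                            const tensorflow::Tensor& cpu_tensor,
                                            tensorflow::Tensor* device_tensor) {
      tensorflow::Notification done;
      tensorflow::Status copy_status;
      ctx->CopyCPUTensorToDevice(
          &cpu_tensor, gpu_device, device_tensor,
          [&copy_status, &done](const tensorflow::Status& s) {
            copy_status = s;
            done.Notify();
          });
      done.WaitForNotification();
      return copy_status;
    }

    }  // namespace example
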
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
new file mode 100644
index 0000000000..28afc95c1b
--- /dev/null
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -0,0 +1,160 @@
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session_options.h"
+
+#if defined(PLATFORM_GOOGLE)
+DECLARE_bool(brain_gpu_use_bfc_allocator);
+#else
+extern bool FLAGS_brain_gpu_use_bfc_allocator;
+#endif
+
+namespace tensorflow {
+namespace test {
+
+Benchmark::Benchmark(const string& device, Graph* g,
+ const SessionOptions* options, Graph* init) {
+ RequireDefaultOps();
+
+ FLAGS_brain_gpu_use_bfc_allocator = true;
+
+ SessionOptions default_options;
+ if (!options) {
+ options = &default_options;
+ }
+
+ testing::StopTiming();
+ string t = str_util::Uppercase(device);
+ device_ =
+ DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0");
+ CHECK(device_) << "Could not create a " << device << " device";
+
+ pool_ = new thread::ThreadPool(options->env, "blocking",
+ port::NumSchedulableCPUs());
+
+ auto runner = [this](std::function<void()> closure) {
+ pool_->Schedule(closure);
+ };
+
+ rendez_ = NewLocalRendezvous();
+
+ if (init) {
+ Executor* init_exec;
+ TF_CHECK_OK(NewLocalExecutor(
+ {
+ device_, nullptr, false,
+ [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, kernel);
+ },
+ [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); },
+ },
+ init, &init_exec));
+ Executor::Args args;
+ args.rendezvous = rendez_;
+ args.runner = runner;
+ TF_CHECK_OK(init_exec->Run(args));
+ delete init_exec;
+ }
+
+ TF_CHECK_OK(NewLocalExecutor(
+ {
+ device_,
+ nullptr,
+ false,
+ [this](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, kernel);
+ },
+ [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); },
+ },
+ g, &exec_));
+}
+
+Benchmark::~Benchmark() {
+ if (device_) {
+ rendez_->Unref();
+ delete exec_;
+ delete device_;
+ delete pool_;
+ }
+}
+
+void Benchmark::Run(int iters) { RunWithArgs({}, {}, iters); }
+
+string GetRendezvousKey(const Node* node) {
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "send_device", &send_device));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "recv_device", &recv_device));
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "tensor_name", &tensor_name));
+ uint64 send_device_incarnation;
+ TF_CHECK_OK(GetNodeAttr(node->def(), "send_device_incarnation",
+ reinterpret_cast<int64*>(&send_device_incarnation)));
+ return Rendezvous::CreateKey(send_device, send_device_incarnation,
+ recv_device, tensor_name, FrameAndIter(0, 0));
+}
+
+void Benchmark::RunWithArgs(
+ const std::vector<std::pair<const Node*, Tensor>>& inputs,
+ const std::vector<const Node*>& outputs, int iters) {
+ if (device_) {
+ // Gets inputs' and outputs' rendezvous keys.
+ std::vector<std::pair<string, Tensor>> in;
+ for (const auto& p : inputs) {
+ in.push_back({GetRendezvousKey(p.first), p.second});
+ }
+ std::vector<string> out;
+ for (const auto& n : outputs) {
+ out.push_back(GetRendezvousKey(n));
+ }
+    Tensor unused;  // In the benchmark, we don't care about the return value.
+ bool is_dead;
+
+ // Warm up
+ Executor::Args args;
+ args.rendezvous = rendez_;
+ args.runner = [this](std::function<void()> closure) {
+ pool_->Schedule(closure);
+ };
+ for (int i = 0; i < 3; ++i) {
+ for (const auto& p : in) {
+ rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
+ }
+ TF_CHECK_OK(exec_->Run(args));
+ for (const string& key : out) {
+ rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
+ }
+ }
+ TF_CHECK_OK(device_->Sync());
+
+ testing::StartTiming();
+ while (iters-- > 0) {
+ for (const auto& p : in) {
+ rendez_->Send(p.first, Rendezvous::Args(), p.second, false);
+ }
+ TF_CHECK_OK(exec_->Run(args));
+ for (const string& key : out) {
+ rendez_->Recv(key, Rendezvous::Args(), &unused, &is_dead);
+ }
+ }
+
+ TF_CHECK_OK(device_->Sync());
+ testing::StopTiming();
+ }
+}
+
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
new file mode 100644
index 0000000000..5ebe13e1d4
--- /dev/null
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -0,0 +1,52 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#define TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class Device;
+class SessionOptions;
+
+namespace test {
+
+class Benchmark {
+ public:
+ // "device" must be either "cpu" or "gpu". Takes ownership of "g"
+ // and "init".
+ Benchmark(const string& device, Graph* g,
+ const SessionOptions* options = nullptr, Graph* init = nullptr);
+ ~Benchmark();
+
+ // Executes the graph for "iters" times.
+ void Run(int iters);
+
+ // If "g" contains send/recv nodes, before each execution, we send
+ // inputs to the corresponding recv nodes in the graph, after each
+ // execution, we recv outputs from the corresponding send nodes in
+ // the graph. In the benchmark, we throw away values returned by the
+ // graph.
+ void RunWithArgs(const std::vector<std::pair<const Node*, Tensor>>& inputs,
+ const std::vector<const Node*>& outputs, int iters);
+
+ private:
+ thread::ThreadPool* pool_ = nullptr;
+ thread::ThreadPool* non_blocking_pool_ = nullptr;
+ Device* device_ = nullptr;
+ Rendezvous* rendez_ = nullptr;
+ Executor* exec_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
+};
+
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
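
A sketch of how a microbenchmark might use this class, following the pattern the constructor expects (the Benchmark takes ownership of the graph). The op choice, tensor size, and function name below are illustrative only.

    #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/testlib.h"
    #include "tensorflow/core/platform/test_benchmark.h"
    #include "tensorflow/core/public/tensor.h"

    namespace tensorflow {
    namespace {

    // Benchmarks a single "Neg" op applied to a 1M-element float tensor.
    static void BM_NegCPU(int iters) {
      Graph* g = new Graph(OpRegistry::Global());
      Tensor data(DT_FLOAT, TensorShape({1 << 20}));
      data.flat<float>().setRandom();
      test::graph::Unary(g, "Neg", test::graph::Constant(g, data));
      // test::Benchmark takes ownership of "g" and runs it "iters" times.
      test::Benchmark("cpu", g).Run(iters);
    }
    BENCHMARK(BM_NegCPU);

    }  // namespace
    }  // namespace tensorflow
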
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
new file mode 100644
index 0000000000..6a75346805
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -0,0 +1,51 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session_options.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace {
+
+DeviceBase::CpuWorkerThreads eigen_worker_threads;
+Eigen::ThreadPoolInterface* eigen_thread_pool = nullptr;
+Eigen::ThreadPoolDevice* eigen_device = nullptr;
+
+static bool InitModule(const SessionOptions& options) {
+ int32 intra_op_parallelism_threads =
+ options.config.intra_op_parallelism_threads();
+ if (intra_op_parallelism_threads == 0) {
+ intra_op_parallelism_threads = port::NumSchedulableCPUs();
+ }
+ LOG(INFO) << "Local device intra op parallelism threads: "
+ << intra_op_parallelism_threads;
+ eigen_worker_threads.num_threads = intra_op_parallelism_threads;
+ eigen_worker_threads.workers = new thread::ThreadPool(
+ options.env, "Eigen", intra_op_parallelism_threads);
+ eigen_thread_pool = new EigenThreadPoolWrapper(eigen_worker_threads.workers);
+ eigen_device = new Eigen::ThreadPoolDevice(eigen_thread_pool,
+ eigen_worker_threads.num_threads);
+ return true;
+}
+} // end namespace
+
+// LocalDevice ----------------------------------------------------------------
+
+LocalDevice::LocalDevice(const SessionOptions& options,
+ const DeviceAttributes& attributes,
+ Allocator* device_allocator)
+ : Device(options.env, attributes, device_allocator) {
+ // All ThreadPoolDevices in the process will use this single fixed
+ // sized threadpool for numerical computations.
+ static bool init = InitModule(options);
+ CHECK(init); // Avoids compiler warning that init is unused.
+ set_tensorflow_cpu_worker_threads(&eigen_worker_threads);
+ set_eigen_cpu_device(eigen_device);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
new file mode 100644
index 0000000000..fc4cfc2dfc
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -0,0 +1,27 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+
+namespace tensorflow {
+
+class SessionOptions;
+
+// This class is shared by ThreadPoolDevice and GPUDevice and
+// initializes a shared Eigen compute device used by both. This
+// should eventually be removed once we refactor ThreadPoolDevice and
+// GPUDevice into more 'process-wide' abstractions.
+class LocalDevice : public Device {
+ public:
+ LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes,
+ Allocator* device_allocator);
+ ~LocalDevice() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(LocalDevice);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/local_session.cc b/tensorflow/core/common_runtime/local_session.cc
new file mode 100644
index 0000000000..ab6993b8a2
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_session.cc
@@ -0,0 +1,500 @@
+#include "tensorflow/core/common_runtime/local_session.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/common_runtime/simple_placer.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+namespace {
+
+thread::ThreadPool* kernel_thread_pool_ = nullptr;
+static bool InitModule(const SessionOptions& options) {
+ int32 inter_op_parallelism_threads =
+ options.config.inter_op_parallelism_threads();
+ if (inter_op_parallelism_threads == 0) {
+ // Default to using the number of cores available in the process.
+ inter_op_parallelism_threads = port::NumSchedulableCPUs();
+ }
+ LOG(INFO) << "Local session inter op parallelism threads: "
+ << inter_op_parallelism_threads;
+ kernel_thread_pool_ = new thread::ThreadPool(options.env, "Compute",
+ inter_op_parallelism_threads);
+ return true;
+}
+
+// TODO(vrv): Figure out how to unify the many different functions
+// that generate RendezvousKey, since many of them have to be
+// consistent with each other.
+string GetRendezvousKey(const string& tensor_name,
+ const DeviceAttributes& device_info,
+ const FrameAndIter& frame_iter) {
+ return strings::StrCat(device_info.name(), ";",
+ strings::FpToString(device_info.incarnation()), ";",
+ device_info.name(), ";", tensor_name, ";",
+ frame_iter.frame_id, ":", frame_iter.iter_id);
+}
+
+// NOTE: On Android with a single device, there is never
+// a risk of an OpKernel blocking indefinitely:
+//
+// 1) No operations do I/O that depends on other simultaneous kernels,
+//
+// 2) Recv nodes always complete immediately: The inputs are sent into
+// the local rendezvous before we start the executor, so the
+//    corresponding recvs will not block.
+//
+// Based on these assumptions, we can use the same thread pool for
+// both "non-blocking" and "blocking" OpKernels on Android.
+//
+// This may change down the road when we add support for multiple
+// devices that run concurrently, in which case we will need to
+// revisit this decision.
+void SchedClosure(std::function<void()> c) {
+// TODO(sanjay): Get rid of __ANDROID__ path
+#ifdef __ANDROID__
+ // On Android, there is no implementation of ThreadPool that takes
+ // std::function, only Closure, which we cannot easily convert.
+ //
+ // Instead, we just run the function in-line, which is currently
+ // safe given the reasoning above.
+ c();
+#else
+ kernel_thread_pool_->Schedule(c);
+#endif // __ANDROID__
+}
+
+} // namespace
+
+LocalSession::LocalSession(const SessionOptions& options,
+ const DeviceMgr* device_mgr)
+ : options_(options),
+ device_mgr_(device_mgr),
+ cancellation_manager_(new CancellationManager()) {
+ static bool init = InitModule(options);
+ CHECK(init); // Avoids compiler warning that init is unused.
+ session_handle_ = strings::FpToString(random::New64());
+ int devices_added = 0;
+ if (options.config.log_device_placement()) {
+ const string mapping_str = device_mgr_->DeviceMappingString();
+ printf("Device mapping:\n%s", mapping_str.c_str());
+ LOG(INFO) << "Device mapping:\n" << mapping_str;
+ }
+ for (auto d : device_mgr_->ListDevices()) {
+ devices_.push_back(d);
+ device_set_.AddDevice(d);
+ d->op_segment()->AddHold(session_handle_);
+
+ // The first device added is special: it is the 'client device' (a
+ // CPU device) from which we feed and fetch Tensors.
+ if (devices_added == 0) {
+ device_set_.set_client_device(d);
+ }
+ ++devices_added;
+ }
+}
+
+LocalSession::~LocalSession() {
+ for (auto d : device_mgr_->ListDevices()) {
+ d->op_segment()->RemoveHold(session_handle_);
+ }
+ for (auto it : executors_) {
+ delete it.second;
+ }
+ delete cancellation_manager_;
+}
+
+Status LocalSession::Create(const GraphDef& graph) {
+ mutex_lock l(graph_def_lock_);
+ if (graph_created_) {
+ return errors::AlreadyExists(
+ "A Graph has already been created for this session.");
+ }
+ return ExtendLocked(graph);
+}
+
+Status LocalSession::Extend(const GraphDef& graph) {
+ mutex_lock l(graph_def_lock_);
+ return ExtendLocked(graph);
+}
+
+Status LocalSession::ExtendLocked(const GraphDef& graph) {
+  graph_created_ = true;  // In case this is the first call.
+ graph_def_.MergeFrom(graph);
+ return Status::OK();
+}
+
+Status LocalSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs) {
+ {
+ mutex_lock l(graph_def_lock_);
+ if (!graph_created_) {
+ return errors::InvalidArgument(
+ "Session was not created with a graph before Run()!");
+ }
+ }
+
+  // Extract the input names for this run of the session.
+ std::vector<string> input_tensor_names;
+ input_tensor_names.reserve(inputs.size());
+ for (const auto& it : inputs) {
+ input_tensor_names.push_back(it.first);
+ }
+
+ // Check if we already have an executor for these arguments.
+ ExecutorsAndKeys* executors_and_keys;
+ Status s = GetOrCreateExecutors(input_tensor_names, output_names,
+ target_nodes, &executors_and_keys);
+ if (!s.ok()) {
+ return s;
+ }
+
+ IntraProcessRendezvous* rendez =
+ new IntraProcessRendezvous(device_mgr_.get());
+ core::ScopedUnref rendez_unref(rendez);
+
+ // Insert the input tensors into the local rendezvous by their
+ // rendezvous key.
+ for (const auto& input : inputs) {
+ const string& input_key = executors_and_keys->input_keys[input.first];
+ s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
+ if (!s.ok()) {
+ rendez->StartAbort(s);
+ return s;
+ }
+ }
+
+ // Start parallel Executors.
+ Notification executors_done;
+ const int num_executors = executors_and_keys->device_executors.size();
+ ExecutorBarrier* barrier = new ExecutorBarrier(
+ num_executors, rendez, [&executors_done, &s](const Status& ret) {
+ s = ret;
+ executors_done.Notify();
+ });
+
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.cancellation_manager = cancellation_manager_;
+ args.runner = SchedClosure;
+
+ for (auto device_executor : executors_and_keys->device_executors) {
+ Executor* exec = device_executor.second;
+ exec->RunAsync(args, barrier->Get());
+ }
+
+ executors_done.WaitForNotification();
+
+ TF_RETURN_IF_ERROR(s);
+
+ if (!output_names.empty()) {
+ outputs->resize(output_names.size());
+ }
+
+ // Get the outputs from the rendezvous
+ for (size_t output_offset = 0; output_offset < output_names.size();
+ ++output_offset) {
+ const string& output_key =
+ executors_and_keys->output_keys[output_names[output_offset]];
+ Tensor output_tensor;
+ bool is_dead;
+
+ // Fetch data from the Rendezvous.
+ s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
+ if (is_dead) {
+ s = errors::InvalidArgument("The tensor returned for ",
+ output_names[output_offset],
+ " was not valid.");
+ }
+ if (!s.ok()) {
+ rendez->StartAbort(s);
+ outputs->clear();
+ return s;
+ }
+
+ (*outputs)[output_offset] = output_tensor;
+ }
+
+ return s;
+}
+
+Status LocalSession::GetOrCreateExecutors(
+ gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
+ gtl::ArraySlice<string> target_nodes,
+ ExecutorsAndKeys** executors_and_keys) {
+ // Sort the inputs and outputs, so we don't create separate
+ // executors when a user passes in the same inputs/outputs in
+ // different orders.
+ //
+  // In the future, we could consider some other signature that preserves
+  // the same property, to avoid the sort.
+ std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
+ std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
+ std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
+ std::sort(inputs_sorted.begin(), inputs_sorted.end());
+ std::sort(outputs_sorted.begin(), outputs_sorted.end());
+ std::sort(tn_sorted.begin(), tn_sorted.end());
+
+ const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
+ str_util::Join(outputs_sorted, ","), "/",
+ str_util::Join(tn_sorted, ","));
+
+ // See if we already have the executors for this run.
+ {
+ mutex_lock l(executor_lock_); // could use reader lock
+ auto it = executors_.find(key);
+ if (it != executors_.end()) {
+ *executors_and_keys = it->second;
+ return Status::OK();
+ }
+ }
+
+ // The executor_lock_ is intentionally released while executor is
+ // being created.
+ std::unordered_map<string, Graph*> graphs;
+ Status s = CreateGraphs(inputs, outputs, target_nodes, &graphs);
+ if (!s.ok()) {
+ return s;
+ }
+
+ bool has_control_flow = false;
+ for (const auto& graph : graphs) {
+ for (const Node* n : graph.second->nodes()) {
+ if (IsControlFlow(n)) {
+ has_control_flow = true;
+ break;
+ }
+ }
+ if (has_control_flow) break;
+ }
+
+ std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
+
+ for (const auto& graph : graphs) {
+ const string& partition_name = graph.first;
+ Graph* partition_graph = graph.second;
+
+ Device* d;
+ s = device_mgr_->LookupDevice(partition_name, &d);
+ if (!s.ok()) {
+ return s;
+ }
+
+ LocalExecutorParams params;
+ params.has_control_flow = has_control_flow;
+ params.device = d;
+ params.create_kernel = [this, d](const NodeDef& ndef, OpKernel** kernel) {
+ return CreateCachedKernel(d, session_handle_, nullptr, ndef, kernel);
+ };
+ params.delete_kernel = [this, d](OpKernel* kernel) {
+ DeleteCachedKernel(d, session_handle_, kernel);
+ };
+
+ Executor* tmp_exec;
+ s = NewLocalExecutor(params, partition_graph, &tmp_exec);
+ if (!s.ok()) {
+ return s;
+ }
+ ek->device_executors.insert(std::make_pair(graph.first, tmp_exec));
+ }
+
+ // Compute the rendezvous keys to avoid recomputing them every time.
+ //
+ // We always use the first device as the device name portion of the
+ // key, even if we're feeding another graph.
+ for (const string& input : inputs) {
+ ek->input_keys[input] = GetRendezvousKey(
+ input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
+ }
+ for (const string& output : outputs) {
+ ek->output_keys[output] = GetRendezvousKey(
+ output, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
+ }
+
+ // Reacquire the lock, try to insert into the map.
+ mutex_lock l(executor_lock_);
+ const bool inserted = executors_.insert(std::make_pair(key, ek.get())).second;
+ if (!inserted) {
+ // Another thread created the entry before us, so delete the
+ // one we created and return the already created one.
+ auto it = executors_.find(key);
+ *executors_and_keys = it->second;
+ } else {
+ *executors_and_keys = ek.release();
+ }
+
+ return Status::OK();
+}
+
+void LocalSession::SaveStatefulNodes(Graph* graph) {
+ for (Node* n : graph->nodes()) {
+ if (n->op_def().is_stateful()) {
+ VLOG(2) << "Saving " << n->DebugString();
+ stateful_placements_[n->name()] = n->assigned_device_name();
+ }
+ }
+}
+
+void LocalSession::RestoreStatefulNodes(Graph* graph) {
+ for (Node* n : graph->nodes()) {
+ if (n->op_def().is_stateful()) {
+ auto iter = stateful_placements_.find(n->name());
+ if (iter != stateful_placements_.end()) {
+ n->set_assigned_device_name(iter->second);
+ VLOG(2) << "Restored " << n->DebugString();
+ }
+ }
+ }
+}
+
+Status LocalSession::CreateGraphs(gtl::ArraySlice<string> feeds,
+ gtl::ArraySlice<string> fetches,
+ gtl::ArraySlice<string> target_nodes,
+ std::unordered_map<string, Graph*>* outputs) {
+ Graph graph(OpRegistry::Global());
+ GraphConstructorOptions opts;
+
+ {
+ mutex_lock l(graph_def_lock_);
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, &graph));
+ }
+
+ TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
+ &graph, feeds, fetches, target_nodes,
+ device_set_.client_device()->attributes()));
+
+ // Run the simple placer after rewriting the graph.
+ std::unordered_map<string, int32> node_name_to_cost_map;
+ for (Node* n : graph.nodes()) {
+ node_name_to_cost_map[n->name()] = n->cost_id();
+ }
+ SimplePlacer placer(&graph, &device_set_, &node_name_to_cost_map, &options_);
+
+ {
+ mutex_lock l(mu_);
+ // Restore stateful nodes.
+ RestoreStatefulNodes(&graph);
+ TF_RETURN_IF_ERROR(placer.Run());
+ // Save stateful nodes.
+ SaveStatefulNodes(&graph);
+ }
+
+ // Partition the graph across devices.
+ std::unordered_map<string, GraphDef> partitions;
+ PartitionOptions popts;
+ popts.node_to_loc = [](const Node* node) {
+ return node->assigned_device_name();
+ };
+ popts.new_name = [this](const string& prefix) {
+ mutex_lock l(mu_);
+ return strings::StrCat(prefix, "/_", name_counter_++);
+ };
+ popts.get_incarnation = [](const string& name) {
+ // The local session does not have changing incarnation numbers.
+ // Just return '1'.
+ return 1;
+ };
+ popts.control_flow_added = false;
+ TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
+
+ std::vector<string> device_names;
+ for (auto device : devices_) {
+ // Extract the LocalName from the device.
+ device_names.push_back(DeviceNameUtils::LocalName(device->name()));
+ }
+
+ // Check for valid partitions.
+ for (const auto& partition : partitions) {
+ const string& local_partition_name =
+ DeviceNameUtils::LocalName(partition.first);
+ if (std::count(device_names.begin(), device_names.end(),
+ local_partition_name) == 0) {
+ return errors::InvalidArgument(
+ "Creating a partition for ", local_partition_name,
+ " which doesn't exist in the list of available devices. Available "
+ "devices: ",
+ str_util::Join(device_names, ","));
+ }
+ }
+
+ for (const auto& partition : partitions) {
+ const string& partition_name = partition.first;
+
+ const GraphDef& graph_def = partition.second;
+ VLOG(2) << "Created " << graph_def.DebugString() << " for "
+ << partition_name;
+
+ Graph* device_graph = new Graph(OpRegistry::Global());
+ GraphConstructorOptions device_opts;
+ // There are internal operations (e.g., send/recv) that we now
+ // allow.
+ device_opts.allow_internal_ops = true;
+ device_opts.expect_device_spec = true;
+ Status s =
+ ConvertGraphDefToGraph(device_opts, graph_def, device_graph);
+ if (!s.ok()) {
+ delete device_graph;
+ // Also delete other graphs created during the loop.
+ gtl::STLDeleteValues(outputs);
+ return s;
+ }
+ outputs->insert(std::make_pair(partition_name, device_graph));
+ }
+
+ return Status::OK();
+}
+
+::tensorflow::Status LocalSession::Close() {
+ cancellation_manager_->StartCancel();
+ return ::tensorflow::Status::OK();
+}
+
+class LocalSessionFactory : public SessionFactory {
+ public:
+ LocalSessionFactory() {}
+
+ Session* NewSession(const SessionOptions& options) override {
+ std::vector<Device*> devices;
+ DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
+ &devices);
+ return new LocalSession(options, new DeviceMgr(devices));
+ }
+};
+
+class LocalSessionRegistrar {
+ public:
+ LocalSessionRegistrar() {
+ SessionFactory::Register("LOCAL_SESSION", new LocalSessionFactory());
+ }
+};
+static LocalSessionRegistrar registrar;
+
+} // namespace tensorflow
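
Two details of the implementation above are worth making concrete. The rendezvous keys produced by GetRendezvousKey() are plain strings of the form "<device>;<incarnation fingerprint>;<device>;<tensor name>;<frame>:<iter>", and executors are cached under a signature built from the sorted feeds, fetches, and targets. The stand-alone sketch below (plain C++, no TensorFlow dependencies) mirrors that signature construction; the function name is invented for the example.

    #include <algorithm>
    #include <string>
    #include <vector>

    // Mirrors the "<inputs>-><outputs>/<targets>" key layout used by
    // LocalSession::GetOrCreateExecutors(). Sorting makes the key order
    // independent, so {"a","b"} and {"b","a"} share one cached executor.
    std::string ExecutorCacheKey(std::vector<std::string> inputs,
                                 std::vector<std::string> outputs,
                                 std::vector<std::string> targets) {
      std::sort(inputs.begin(), inputs.end());
      std::sort(outputs.begin(), outputs.end());
      std::sort(targets.begin(), targets.end());
      auto join = [](const std::vector<std::string>& parts) {
        std::string joined;
        for (size_t i = 0; i < parts.size(); ++i) {
          if (i > 0) joined += ",";
          joined += parts[i];
        }
        return joined;
      };
      return join(inputs) + "->" + join(outputs) + "/" + join(targets);
    }

    // e.g. ExecutorCacheKey({"x:0"}, {"y:0"}, {"y_neg"}) == "x:0->y:0/y_neg"
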
diff --git a/tensorflow/core/common_runtime/local_session.h b/tensorflow/core/common_runtime/local_session.h
new file mode 100644
index 0000000000..453cfdde47
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_session.h
@@ -0,0 +1,109 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
+#define TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Device;
+
+class LocalSession : public Session {
+ public:
+ // Takes ownership of 'device_mgr'.
+ LocalSession(const SessionOptions& options, const DeviceMgr* device_mgr);
+ ~LocalSession() override;
+
+ ::tensorflow::Status Create(const GraphDef& graph) override;
+ ::tensorflow::Status Extend(const GraphDef& graph) override;
+ ::tensorflow::Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs) override;
+ ::tensorflow::Status Close() override;
+
+ private:
+ struct ExecutorsAndKeys {
+ std::unordered_map<string, Executor*> device_executors;
+ std::unordered_map<string, string> input_keys;
+ std::unordered_map<string, string> output_keys;
+
+ ~ExecutorsAndKeys() {
+ for (auto it : device_executors) {
+ delete it.second;
+ }
+ }
+ };
+
+ // Retrieves an already existing set of executors to run 'inputs' and
+ // 'outputs', or creates and caches them for future use.
+ ::tensorflow::Status GetOrCreateExecutors(
+ gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
+ gtl::ArraySlice<string> target_nodes,
+ ExecutorsAndKeys** executors_and_keys);
+
+  // Creates several graphs given the existing graph_def_ and the
+  // input feeds and fetches, partitioned across the available devices.
+ ::tensorflow::Status CreateGraphs(
+ gtl::ArraySlice<string> feeds, gtl::ArraySlice<string> fetches,
+ gtl::ArraySlice<string> target_nodes,
+ std::unordered_map<string, Graph*>* outputs);
+
+ ::tensorflow::Status ExtendLocked(const GraphDef& graph)
+ EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+
+ const SessionOptions options_;
+
+ // Device structures.
+ const std::unique_ptr<const DeviceMgr> device_mgr_;
+ std::vector<Device*> devices_; // not owned
+ DeviceSet device_set_;
+
+ string session_handle_;
+ bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
+
+ mutex graph_def_lock_;
+ GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
+
+ mutex executor_lock_; // protects executors_
+ // Holds mappings from signature to the executors that process
+ // it. The reason for a level of indirection around mapped_type is
+ // to guarantee address stability.
+ std::unordered_map<string, ExecutorsAndKeys*> executors_
+ GUARDED_BY(executor_lock_);
+
+ CancellationManager* cancellation_manager_;
+
+ // Saves and restores device placements for stateful nodes.
+ mutex mu_;
+ void SaveStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ void RestoreStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Map of placed stateful nodes, i.e. nodes for which is_stateful()
+  // is true, such as "params" and "queue" nodes. Once placed, these
+  // nodes cannot be moved to a different device. Maps node names to
+ // device names.
+ std::unordered_map<string, string> stateful_placements_ GUARDED_BY(mu_);
+
+ // For generating unique names.
+ int64 name_counter_ GUARDED_BY(mu_) = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LocalSession);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_SESSION_H_
diff --git a/tensorflow/core/common_runtime/local_session_test.cc b/tensorflow/core/common_runtime/local_session_test.cc
new file mode 100644
index 0000000000..9325fe44c3
--- /dev/null
+++ b/tensorflow/core/common_runtime/local_session_test.cc
@@ -0,0 +1,314 @@
+#include "tensorflow/core/common_runtime/local_session.h"
+
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+Session* CreateSession() {
+ SessionOptions options;
+ (*options.config.mutable_device_count())["CPU"] = 2;
+ return NewSession(options);
+}
+
+class LocalSessionMinusAXTest : public ::testing::Test {
+ public:
+ void Initialize(std::initializer_list<float> a_values) {
+ RequireDefaultOps();
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+ test::FillValues<float>(&a_tensor, a_values);
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+ test::FillValues<float>(&x_tensor, {1, 1});
+ Node* x = test::graph::Constant(&graph, x_tensor);
+ x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+ x_ = x->name();
+
+ // y = A * x
+ Node* y = test::graph::Matmul(&graph, a, x, false, false);
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+ y_ = y->name();
+
+ Node* y_neg = test::graph::Unary(&graph, "Neg", y);
+ y_neg_ = y_neg->name();
+ y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
+ test::graph::ToGraphDef(&graph, &def_);
+ }
+
+ string x_;
+ string y_;
+ string y_neg_;
+ GraphDef def_;
+};
+
+TEST_F(LocalSessionMinusAXTest, RunSimpleNetwork) {
+ Initialize({3, 2, -1, 0});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+ Status s = session->Run(inputs, output_names, target_nodes, &outputs);
+ ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+  // The first output should be initialized and have the correct
+  // value.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
+TEST_F(LocalSessionMinusAXTest, TestFeed) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+
+ ASSERT_OK(session->Create(def_));
+
+ // Fill in the input and ask for the output
+ //
+ // Note that the input being fed is on the second device.
+ Tensor t(DT_FLOAT, TensorShape({2, 1}));
+ t.matrix<float>()(0, 0) = 5;
+ t.matrix<float>()(1, 0) = 6;
+ std::vector<std::pair<string, Tensor>> inputs = {{x_, t}};
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<Tensor> outputs;
+
+ // Run the graph
+ Status s = session->Run(inputs, output_names, {}, &outputs);
+ ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ auto mat = outputs[0].matrix<float>();
+
+  // Expect outputs to be: 1*5 + 2*6, 3*5 + 4*6
+ EXPECT_FLOAT_EQ(17.0, mat(0, 0));
+ EXPECT_FLOAT_EQ(39.0, mat(1, 0));
+}
+
+TEST_F(LocalSessionMinusAXTest, TestConcurrency) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def_));
+
+  // Set up a thread pool for running the graph concurrently.
+ thread::ThreadPool* tp = new thread::ThreadPool(Env::Default(), "test", 4);
+
+ // Run the graph 1000 times in 4 different threads concurrently.
+ std::vector<string> output_names = {y_ + ":0"};
+ auto fn = [&session, output_names]() {
+ for (int i = 0; i < 1000; ++i) {
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+ // Run the graph
+ Status s = session->Run(inputs, output_names, {}, &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(1, outputs.size());
+ auto mat = outputs[0].matrix<float>();
+ EXPECT_FLOAT_EQ(3.0, mat(0, 0));
+ }
+ };
+
+ for (int i = 0; i < 4; ++i) {
+ tp->Schedule(fn);
+ }
+
+ // Wait for the functions to finish.
+ delete tp;
+}
+
+TEST_F(LocalSessionMinusAXTest, TwoCreateCallsFails) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def_));
+
+  // A second Create() call should fail.
+ ASSERT_FALSE(session->Create(def_).ok());
+}
+
+TEST_F(LocalSessionMinusAXTest, ForgetToCreate) {
+ Initialize({1, 2, 3, 4});
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+ ASSERT_FALSE(session->Run(inputs, {y_ + ":0"}, {y_neg_}, &outputs).ok());
+}
+
+TEST_F(LocalSessionMinusAXTest, InvalidDevice) {
+ GraphDef def;
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+ a_tensor.flat<float>().setRandom();
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+ Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+ x_tensor.flat<float>().setRandom();
+ Node* x = test::graph::Constant(&graph, x_tensor);
+ x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+  // Assign y to an invalid device (only cpu:0 and cpu:1 exist).
+ Node* y = test::graph::Matmul(&graph, a, x, false, false);
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:2");
+
+ test::graph::ToGraphDef(&graph, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<string> output_names = {y->name() + ":0"};
+ std::vector<Tensor> outputs;
+
+ // Should return an error.
+ ASSERT_FALSE(session->Run(inputs, output_names, {}, &outputs).ok());
+
+ // Fix placement and run again
+ def.Clear();
+ y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+ test::graph::ToGraphDef(&graph, &def);
+ session.reset(CreateSession());
+ ASSERT_OK(session->Create(def));
+ ASSERT_OK(session->Run(inputs, output_names, {}, &outputs));
+}
+
+TEST(LocalSessionTest, KeepsStateAcrossRunsOfSession) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+ Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
+ var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ Tensor twenty(DT_FLOAT, TensorShape({10}));
+ for (int i = 0; i < 10; ++i) {
+ twenty.flat<float>()(i) = 20.0;
+ }
+
+ Node* twenty_node = test::graph::Constant(&g, twenty);
+ twenty_node->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/cpu:0");
+
+ Node* init = test::graph::Assign(&g, var, twenty_node);
+ init->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+
+ // Initialize the variable
+ Status s = session->Run(inputs, {init->name()}, {}, &outputs);
+ ASSERT_OK(s);
+
+ // Get the variable's data
+ s = session->Run(inputs, {var->name() + ":0"}, {}, &outputs);
+ ASSERT_OK(s);
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_EQ(20.0, outputs[0].flat<float>()(0));
+}
+
+TEST(LocalSessionTest, MultipleFeedTest) {
+ GraphDef def;
+ Graph g(OpRegistry::Global());
+ Node* var = test::graph::Var(&g, DT_FLOAT, TensorShape({10}));
+ var->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+
+ Tensor first_value(DT_FLOAT, TensorShape({}));
+ first_value.scalar<float>()() = 1.0;
+ Node* first_const = test::graph::Constant(&g, first_value);
+ Node* first_identity = test::graph::Identity(&g, first_const);
+
+ Tensor second_value(DT_FLOAT, TensorShape({}));
+ second_value.scalar<float>()() = 2.0;
+ Node* second_const = test::graph::Constant(&g, second_value);
+ Node* second_identity = test::graph::Identity(&g, second_const);
+
+ test::graph::ToGraphDef(&g, &def);
+
+ std::unique_ptr<Session> session(CreateSession());
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_OK(session->Create(def));
+
+ std::vector<Tensor> outputs;
+
+ // Fetch without feeding.
+ Status s = session->Run(
+ {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(1.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(2.0, outputs[1].flat<float>()(0));
+
+ s = session->Run(
+ {}, {second_identity->name() + ":0", first_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(2.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
+
+ Tensor value_11(DT_FLOAT, TensorShape({}));
+ value_11.scalar<float>()() = 11.0;
+ Tensor value_22(DT_FLOAT, TensorShape({}));
+ value_22.scalar<float>()() = 22.0;
+
+ // Feed [first_const, second_const]
+ s = session->Run(
+ {{first_const->name(), value_11}, {second_const->name(), value_22}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
+
+ // Feed [second_const, first_const]
+ s = session->Run(
+ {{second_const->name(), value_22}, {first_const->name(), value_11}},
+ {first_identity->name() + ":0", second_identity->name() + ":0"}, {},
+ &outputs);
+ ASSERT_TRUE(s.ok());
+ ASSERT_EQ(2, outputs.size());
+ ASSERT_EQ(11.0, outputs[0].flat<float>()(0));
+ ASSERT_EQ(22.0, outputs[1].flat<float>()(0));
+}
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
new file mode 100644
index 0000000000..111dea6d4c
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -0,0 +1,170 @@
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+
+#include <unordered_set>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \
+ (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#endif
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+namespace {
+
+void CopyTensorBetweenDevices(const string& id, DeviceContext* send_dev_context,
+ DeviceContext* recv_dev_context, Device* src,
+ Device* dst,
+ const AllocatorAttributes src_alloc_attr,
+ const AllocatorAttributes dst_alloc_attr,
+ const Tensor* input, Tensor* output,
+ std::function<void(const Status&)> done) {
+ if (src->attributes().device_type() != dst->attributes().device_type()) {
+ done(errors::Unimplemented(
+ "Copy between device types not yet implemented: src=", src->name(),
+ " dst=", dst->name()));
+    return;
+  } else if (src->attributes().device_type() != "CPU") {
+    done(errors::Unimplemented(
+        "Copy between non-CPU devices not yet implemented"));
+    return;
+  }
+ *output = *input;
+ done(Status::OK());
+}
+
+#if (!defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)) && \
+ (defined(PLATFORM_GOOGLE) || GOOGLE_CUDA)
+constexpr auto CopyTensorBetweenDevicesFunc = &GPUUtil::CopyViaDMA;
+#else
+constexpr auto CopyTensorBetweenDevicesFunc = &CopyTensorBetweenDevices;
+#endif
+
+} // end namespace
+
+IntraProcessRendezvous::IntraProcessRendezvous(const DeviceMgr* device_mgr)
+ : device_mgr_(device_mgr), local_(NewLocalRendezvous()) {}
+
+IntraProcessRendezvous::~IntraProcessRendezvous() { local_->Unref(); }
+
+Status IntraProcessRendezvous::Send(const string& key,
+ const Rendezvous::Args& args,
+ const Tensor& val, const bool is_dead) {
+ VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key;
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) return status_;
+ }
+ Rendezvous::ParsedKey parsed;
+ TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
+
+ // Buffers "val" and "device_context" in local_.
+ return local_->Send(key, args, val, is_dead);
+}
+
+Status IntraProcessRendezvous::ParseKey(const string& key, bool is_src,
+ Rendezvous::ParsedKey* parsed) {
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) return status_;
+ }
+ TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
+ return Status::OK();
+}
+
+void IntraProcessRendezvous::SameWorkerRecvDone(
+ const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
+ StatusCallback done) {
+ // Do a quick copy (sharing the underlying buffer) if both tensors
+ // are on host memory.
+ const bool src_host =
+ (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
+ const bool dst_host =
+ (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
+ if (src_host && dst_host) {
+ *out = in;
+ done(Status::OK());
+ return;
+ }
+
+ // This copy must involve a non-CPU device. Hence, "in" must support DMA
+ // (e.g., string tensors do not work on GPU).
+ if (!DataTypeCanUseMemcpy(in.dtype())) {
+ done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
+ " tensor may not be copied from/to a GPU."));
+ return;
+ }
+
+ Device* src_device;
+ Status s = device_mgr_->LookupDevice(parsed.src_device, &src_device);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ Device* dst_device;
+ s = device_mgr_->LookupDevice(parsed.dst_device, &dst_device);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
+ AllocatorAttributes attr = recv_args.alloc_attrs;
+ attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
+ recv_args.alloc_attrs.gpu_compatible());
+ Allocator* out_allocator = dst_device->GetAllocator(attr);
+ Tensor copy(out_allocator, in.dtype(), in.shape());
+ *out = copy;
+
+ CopyTensorBetweenDevicesFunc(parsed.edge_name, send_args.device_context,
+ recv_args.device_context, src_device, dst_device,
+ send_args.alloc_attrs, recv_args.alloc_attrs,
+ &in, out, done);
+}
+
+void IntraProcessRendezvous::RecvAsync(const string& key,
+ const Rendezvous::Args& recv_args,
+ DoneCallback done) {
+ VLOG(1) << "IntraProcessRendezvous Recv " << this << " " << key;
+
+ Rendezvous::ParsedKey parsed;
+ Status s = ParseKey(key, false /*!is_src*/, &parsed);
+ if (!s.ok()) {
+ done(s, Args(), recv_args, Tensor(), false);
+ return;
+ }
+
+ // Recv the tensor from local_.
+ local_->RecvAsync(key, recv_args, [this, parsed, done](
+ const Status& status,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& in, bool is_dead) {
+ Status s = status;
+ Tensor* out = new Tensor;
+ StatusCallback final_callback = [done, send_args, recv_args, out,
+ is_dead](const Status& s) {
+ done(s, send_args, recv_args, *out, is_dead);
+ delete out;
+ };
+
+ if (s.ok()) {
+ SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback);
+ } else {
+ final_callback(s);
+ }
+ });
+}
+
+void IntraProcessRendezvous::StartAbort(const Status& s) {
+ CHECK(!s.ok());
+ local_->StartAbort(s);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
new file mode 100644
index 0000000000..eaae65f956
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -0,0 +1,73 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// IntraProcessRendezvous is a Rendezvous which expects all producers
+// and consumers to be devices immediately accessible within the
+// process. That is, it will never be necessary to perform an RPC to
+// communicate with either.
+//
+// Buffering of Tensor values is delegated to a "local" Rendezvous
+// obtained from NewLocalRendezvous(). This class just adds
+// functionality to coordinate multiple process-local devices.
+class IntraProcessRendezvous : public Rendezvous {
+ public:
+ explicit IntraProcessRendezvous(const DeviceMgr* device_mgr);
+
+ // Forwards to local_, where the Tensor "val" will be buffered and
+ // any waiting callback stored.
+ Status Send(const string& key, const Rendezvous::Args& args,
+ const Tensor& val, const bool is_dead) override;
+
+ // This method is called only by the RecvOp. It tests to see
+ // whether the value will be produced by a local or remote device
+ // and handles accordingly. In the local case it forwards to
+ // local_, in the remote case it initiates an RPC request.
+ void RecvAsync(const string& key, const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ void StartAbort(const Status& status) override;
+
+ private:
+ const DeviceMgr* device_mgr_;
+ Rendezvous* local_; // Owns a Ref on this object.
+
+ mutable mutex mu_;
+
+ // Status given by StartAbort() if any.
+ Status status_ GUARDED_BY(mu_);
+
+ ~IntraProcessRendezvous() override;
+
+ // Parses "key" into "parsed". If "is_src" is true, checks that the
+ // rendezvous key's source is in this process. If "is_src" is false,
+ // checks that the rendezvous key's destination is in this process.
+ Status ParseKey(const string& key, bool is_src,
+ Rendezvous::ParsedKey* parsed);
+
+ // Callback handling the case when a rendezvous has been
+ // accomplished in local_ and the consumer is local to this process.
+ // Tensor "in" will be copied into "out". The key "parsed" encodes
+ // the src and dst devices.
+ typedef std::function<void(const Status&)> StatusCallback;
+ void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in,
+ Tensor* out, StatusCallback done);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(IntraProcessRendezvous);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
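
A sketch of the Send/RecvAsync flow described above: a value is buffered under a key built with Rendezvous::CreateKey() and then received through the asynchronous callback. The function and namespace are invented for the example, the DeviceMgr is assumed to be supplied by the caller, and error handling is reduced to CHECKs.

    #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
    #include "tensorflow/core/framework/rendezvous.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/public/tensor.h"

    namespace example {

    // Sends a tensor through an IntraProcessRendezvous and receives it back
    // on the same (CPU) device. The DeviceMgr is assumed to be owned by the
    // caller; the device name and incarnation are placeholders.
    void RoundTrip(const tensorflow::DeviceMgr* device_mgr,
                   const tensorflow::Tensor& value) {
      tensorflow::IntraProcessRendezvous* rendez =
          new tensorflow::IntraProcessRendezvous(device_mgr);

      const tensorflow::string key = tensorflow::Rendezvous::CreateKey(
          "/job:localhost/replica:0/task:0/cpu:0", 1 /* incarnation */,
          "/job:localhost/replica:0/task:0/cpu:0", "tensor_a",
          tensorflow::FrameAndIter(0, 0));

      // Buffers "value" in the underlying local rendezvous.
      tensorflow::Status s =
          rendez->Send(key, tensorflow::Rendezvous::Args(), value, false);
      CHECK(s.ok()) << s;

      // The callback fires as soon as the buffered value is matched.
      rendez->RecvAsync(key, tensorflow::Rendezvous::Args(),
                        [](const tensorflow::Status& status,
                           const tensorflow::Rendezvous::Args& send_args,
                           const tensorflow::Rendezvous::Args& recv_args,
                           const tensorflow::Tensor& received, bool is_dead) {
                          CHECK(status.ok()) << status;
                        });

      rendez->Unref();  // Rendezvous objects are reference counted.
    }

    }  // namespace example
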
diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc
new file mode 100644
index 0000000000..6d1ab5cea4
--- /dev/null
+++ b/tensorflow/core/common_runtime/session.cc
@@ -0,0 +1,51 @@
+#include <string>
+
+#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+namespace {
+Status GetFactory(const SessionOptions& options, SessionFactory** ret) {
+ string runtime_type = "LOCAL_SESSION";
+ if (!options.target.empty()) {
+ // Use the service based session.
+ runtime_type = "REMOTE_SESSION";
+ }
+ *ret = SessionFactory::GetFactory(runtime_type);
+ if (!*ret) {
+ return errors::NotFound("Could not find session factory for ",
+ runtime_type);
+ }
+ return Status::OK();
+}
+} // end namespace
+
+Session* NewSession(const SessionOptions& options) {
+ SessionFactory* factory;
+ Status s = GetFactory(options, &factory);
+ if (!s.ok()) {
+ LOG(ERROR) << s;
+ return nullptr;
+ }
+ return factory->NewSession(options);
+}
+
+Status NewSession(const SessionOptions& options, Session** out_session) {
+ SessionFactory* factory;
+ Status s = GetFactory(options, &factory);
+ if (!s.ok()) {
+ *out_session = nullptr;
+ LOG(ERROR) << s;
+ return s;
+ }
+ *out_session = factory->NewSession(options);
+ if (!*out_session) {
+ return errors::Internal("Failed to create session.");
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
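
A minimal caller-side sketch of the factory selection above: an empty target picks the LOCAL_SESSION factory, while a non-empty target would look up REMOTE_SESSION. The Status-returning overload is used so the failure can be inspected; the surrounding main() is illustrative only.

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/public/session_options.h"

    int main() {
      tensorflow::SessionOptions options;  // empty target -> LOCAL_SESSION
      tensorflow::Session* session = nullptr;
      tensorflow::Status s = tensorflow::NewSession(options, &session);
      if (!s.ok()) {
        LOG(ERROR) << "Could not create session: " << s;
        return 1;
      }
      // ... session->Create(graph_def), session->Run(...), session->Close() ...
      delete session;
      return 0;
    }
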
diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc
new file mode 100644
index 0000000000..666b99812d
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_factory.cc
@@ -0,0 +1,41 @@
+#include "tensorflow/core/common_runtime/session_factory.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+namespace tensorflow {
+namespace {
+
+static mutex* get_session_factory_lock() {
+ static mutex session_factory_lock;
+ return &session_factory_lock;
+}
+
+typedef std::unordered_map<string, SessionFactory*> SessionFactories;
+SessionFactories* session_factories() {
+ static SessionFactories* factories = new SessionFactories;
+ return factories;
+}
+
+} // namespace
+
+void SessionFactory::Register(const string& runtime_type,
+ SessionFactory* factory) {
+ mutex_lock l(*get_session_factory_lock());
+ if (!session_factories()->insert({runtime_type, factory}).second) {
+    LOG(ERROR) << "Two session factories are being registered "
+               << "under " << runtime_type;
+ }
+}
+
+SessionFactory* SessionFactory::GetFactory(const string& runtime_type) {
+ mutex_lock l(*get_session_factory_lock()); // could use reader lock
+ auto it = session_factories()->find(runtime_type);
+ if (it == session_factories()->end()) {
+ return nullptr;
+ }
+ return it->second;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h
new file mode 100644
index 0000000000..f770ba93ff
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_factory.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+#define TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Session;
+class SessionOptions;
+
+class SessionFactory {
+ public:
+ virtual Session* NewSession(const SessionOptions& options) = 0;
+ virtual ~SessionFactory() {}
+ static void Register(const string& runtime_type, SessionFactory* factory);
+ static SessionFactory* GetFactory(const string& runtime_type);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
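
The registration hook above is typically used from a static registrar object, following the LocalSessionRegistrar pattern in local_session.cc. The sketch below registers a do-nothing session under an invented runtime-type string; all of the "My*" names are hypothetical.

    #include <string>
    #include <utility>
    #include <vector>

    #include "tensorflow/core/common_runtime/session_factory.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/public/session.h"
    #include "tensorflow/core/public/session_options.h"
    #include "tensorflow/core/public/tensor.h"

    namespace example {

    // A stub Session: every method succeeds or reports Unimplemented.
    class MySession : public tensorflow::Session {
     public:
      explicit MySession(const tensorflow::SessionOptions& options) {}
      ~MySession() override {}

      tensorflow::Status Create(const tensorflow::GraphDef& graph) override {
        return tensorflow::Status::OK();
      }
      tensorflow::Status Extend(const tensorflow::GraphDef& graph) override {
        return tensorflow::Status::OK();
      }
      tensorflow::Status Run(
          const std::vector<std::pair<tensorflow::string, tensorflow::Tensor>>&
              inputs,
          const std::vector<tensorflow::string>& output_names,
          const std::vector<tensorflow::string>& target_nodes,
          std::vector<tensorflow::Tensor>* outputs) override {
        return tensorflow::errors::Unimplemented("MySession::Run is a stub");
      }
      tensorflow::Status Close() override { return tensorflow::Status::OK(); }
    };

    class MySessionFactory : public tensorflow::SessionFactory {
     public:
      tensorflow::Session* NewSession(
          const tensorflow::SessionOptions& options) override {
        return new MySession(options);
      }
    };

    // Runs at load time, exactly like LocalSessionRegistrar above.
    class MySessionRegistrar {
     public:
      MySessionRegistrar() {
        tensorflow::SessionFactory::Register("MY_SESSION",
                                             new MySessionFactory());
      }
    };
    static MySessionRegistrar my_registrar;

    }  // namespace example
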
diff --git a/tensorflow/core/common_runtime/session_options.cc b/tensorflow/core/common_runtime/session_options.cc
new file mode 100644
index 0000000000..ef585efb5c
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_options.cc
@@ -0,0 +1,9 @@
+#include "tensorflow/core/public/session_options.h"
+
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+SessionOptions::SessionOptions() : env(Env::Default()) {}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc
new file mode 100644
index 0000000000..82b5d7ffb0
--- /dev/null
+++ b/tensorflow/core/common_runtime/session_test.cc
@@ -0,0 +1,17 @@
+#include "tensorflow/core/public/session.h"
+
+#include "tensorflow/core/public/session_options.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(SessionTest, InvalidTargetReturnsNull) {
+ SessionOptions options;
+ options.target = "invalid target";
+
+ EXPECT_EQ(nullptr, tensorflow::NewSession(options));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
new file mode 100644
index 0000000000..1cd1db29db
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -0,0 +1,559 @@
+#include "tensorflow/core/common_runtime/simple_placer.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Returns a list of devices sorted by name from 'devices' whose type is in
+// 'supported_device_types'. This function searches in order of the device
+// types in 'supported_device_types' and returns the *first* subset of devices
+// that match.
+//
+// For example, if supported_device_types contains {GPU, CPU} and
+// 'devices' contains CPU and GPU devices, the returned vector will
+// include *only* GPU devices, since that is higher in the priority
+// order in 'supported_device_types'.
+std::vector<Device*> FilterSupportedDevices(
+ const std::vector<Device*>& devices,
+ const DeviceTypeVector& supported_device_types) {
+ std::vector<Device*> filtered_devices;
+ auto device_sort = [](const Device* a, const Device* b) {
+ return a->name() < b->name();
+ };
+ for (DeviceType d : supported_device_types) {
+ for (Device* device : devices) {
+ if (DeviceType(device->attributes().device_type()) == d) {
+ filtered_devices.emplace_back(device);
+ }
+ }
+
+ // If there are any devices under this device type, return this
+ // subset.
+ if (!filtered_devices.empty()) {
+ std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
+ return filtered_devices;
+ }
+ }
+
+ std::sort(filtered_devices.begin(), filtered_devices.end(), device_sort);
+ return filtered_devices;
+}
+
+bool HasColocatedNodeName(const Node& node) {
+ return StringPiece(node.def().device()).starts_with("@");
+}
+
+Status ParseColocatedNodeName(const Node& node,
+ string* out_colocated_node_name) {
+ StringPiece device(node.def().device());
+ if (!device.Consume("@")) {
+ return errors::InvalidArgument("Malformed colocated node name: '", device,
+ "'");
+ }
+ // TODO(mrry): Validate that the node name is a valid node name.
+ *out_colocated_node_name = device.ToString();
+ return Status::OK();
+}
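+
+// For example, a NodeDef whose device field is "@var_0" is interpreted as
+// "colocate this node with the node named 'var_0'", rather than as a
+// (partial) device name.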
+
+// This class maintains the connected components of a colocation
+// constraint graph, and uses this information to assign a satisfying
+// device placement to the nodes of the graph.
+//
+// The typical usage pattern is:
+//
+// Graph graph = ...;
+// DeviceSet device_set = ...;
+// ColocationGraph colocation_graph(graph, device_set);
+//
+// // Add all the nodes of graph to colocation_graph.
+// for (Node* node : graph.nodes()) {
+// TF_RETURN_IF_ERROR(colocation_graph.AddNode(*node));
+// }
+//
+// // Add one or more colocation constraint.
+// Node node_1 = *graph.FindNodeId(...);
+// Node node_2 = *graph.FindNodeId(...);
+// TF_RETURN_IF_ERROR(colocation_graph.ColocateNodes(node_1, node_2));
+//
+// // Assign devices based on the accumulated constraints.
+// for (Node* node : graph.nodes()) {
+// TF_RETURN_IF_ERROR(colocation_graph.AssignDevice(node));
+// }
+//
+// The implementation uses the union-find algorithm to maintain the
+// connected components efficiently and incrementally as edges
+// (implied by ColocationGraph::ColocateNodes() invocations) are added.
+class ColocationGraph {
+ public:
+ ColocationGraph(Graph* graph, const DeviceSet* device_set,
+ const SessionOptions* options)
+ : device_set_(device_set),
+ device_types_(device_set->PrioritizedDeviceTypeList()),
+ options_(options) {
+ members_.reserve(graph->num_node_ids());
+ }
+
+ // Adds the given node to this ColocationGraph as a singleton.
+ //
+ // NOTE: The implementation assumes that the ids of nodes passed to
+ // this method are dense and zero-based; the memory used will be linear in
+ // the largest node ID.
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status AddNode(const Node& node) {
+ Member member;
+ TF_RETURN_IF_ERROR(InitializeMember(node, &member));
+ CHECK_GE(member.parent, 0);
+ members_.resize(member.parent + 1);
+ members_[member.parent] = std::move(member);
+ return Status::OK();
+ }
+
+ // Merge the (possibly disjoint) sets containing nodes "x" and
+  // "y". Returns OK if all the nodes in the union of these sets can
+ // be placed on the same device type.
+ //
+ // NOTE: If this method returns an error, *this is left in an undefined
+ // state.
+ Status ColocateNodes(const Node& x, const Node& y) {
+ int x_root = FindRoot(x.id());
+ int y_root = FindRoot(y.id());
+ if (x_root != y_root) {
+ // Merge the sets by swinging the parent pointer of the smaller
+ // tree to point to the root of the larger tree. Together with
+ // path compression in ColocationGraph::FindRoot, this ensures
+ // that we do not experience pathological performance on graphs
+ // such as chains.
+ int new_root, old_root;
+ if (members_[x_root].rank < members_[y_root].rank) {
+ // The tree rooted at x_root is shallower, so connect it to
+ // y_root. The rank of y_root is unchanged because its new
+ // child has strictly less rank.
+ members_[x_root].parent = y_root;
+ new_root = y_root;
+ old_root = x_root;
+ } else if (members_[x_root].rank > members_[y_root].rank) {
+ // The tree rooted at y_root is shallower, so connect it to
+ // x_root. The rank of x_root is unchanged because its new
+ // child has strictly less rank.
+ members_[y_root].parent = x_root;
+ new_root = x_root;
+ old_root = y_root;
+ } else {
+ // Both trees have the same rank, so break the tie by choosing
+ // x_root as the new root.
+ members_[y_root].parent = x_root;
+ // Increment the rank of the tree rooted at x_root, because it
+ // is now strictly deeper than before.
+ ++members_[x_root].rank;
+ new_root = x_root;
+ old_root = y_root;
+ }
+
+ // Merge the partial device specifications, and ensure that they are
+ // compatible. NULL options_ is treated as allowing soft placement.
+ // TODO(mrry): Consider enriching the error message by pointing
+ // out which nodes have the explicit partial device
+ // specifications that caused this conflict.
+ TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
+ &members_[new_root].device_name, members_[old_root].device_name,
+ options_ == nullptr || options_->config.allow_soft_placement()));
+
+ // Ensure that the common root has at least one supported device
+ // type, by computing the intersection of
+ // members_[new_root].supported_device_types and
+ // members_[old_root].supported_device_types.
+ MergeSupportedDevices(&members_[new_root].supported_device_types,
+ members_[old_root].supported_device_types);
+      if (members_[new_root].supported_device_types.empty()) {
+ return errors::InvalidArgument(
+ "Cannot colocate nodes '", x.name(), "' and '", y.name(),
+ "' because no device type supports both of those nodes and the "
+ "other nodes colocated with them");
+ }
+ }
+ return Status::OK();
+ }
+
+ // For the given node, subject to the constraints previously given
+ // to this ColocationGraph, set its assigned_device_name. Returns OK
+ // if a satisfying device can be found, otherwise an error.
+ Status AssignDevice(Node* node) {
+ int node_root = FindRoot(node->id());
+ if (members_[node_root].assigned_device == nullptr) {
+ // We have not yet assigned a device for the colocated node set containing
+ // n, so we do so now using the constraints on the root node.
+
+ // "devices" will contain the set of feasible placements for the
+ // colocated node set containing n.
+ std::vector<Device*> devices;
+ if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) {
+ // The root node has a (possibly partial) device
+ // specification, so enumerate the physical devices that
+ // conform to it.
+ device_set_->FindMatchingDevices(members_[node_root].device_name,
+ &devices);
+
+ if (!devices.empty()) {
+ // Filter devices into those that are compatible with the root
+ // node (and its children).
+ devices = FilterSupportedDevices(
+ devices, members_[node_root].supported_device_types);
+ }
+
+ // Perform soft placement if allow_soft_placement is set. options_
+ // being NULL is treated as allowing soft placement.
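+      // For example, a requested device of '/job:a/gpu:11' is relaxed
+      // below to '/job:a', so any device in job 'a' may be matched instead.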
+ if (devices.empty() &&
+ (options_ == nullptr || options_->config.allow_soft_placement())) {
+ // The soft_device_name is the same as the node's device name
+ // without specifying the device type or ID.
+ DeviceNameUtils::ParsedName soft_device_name =
+ members_[node_root].device_name;
+ soft_device_name.type.clear();
+ soft_device_name.has_type = false;
+ soft_device_name.has_id = false;
+ device_set_->FindMatchingDevices(soft_device_name, &devices);
+ if (!devices.empty()) {
+ devices = FilterSupportedDevices(
+ devices, members_[node_root].supported_device_types);
+ }
+ }
+
+ if (devices.empty()) {
+ // Return an error when a physical device that matches an explicit
+ // device specification is not found. This ensures that we don't
+ // assign a node to GPU when the user wanted to force it on CPU.
+ DeviceNameUtils::ParsedName specified_device_name;
+ if (DeviceNameUtils::ParseFullName(node->def().device(),
+ &specified_device_name) &&
+ specified_device_name == members_[node_root].device_name) {
+ // The specified device and merged set device match, and
+ // will appear in the GraphDef (for debugging), so just
+ // print the specified device.
+ return errors::InvalidArgument(
+ "Could not satisfy explicit device specification '",
+ node->def().device(), "'");
+ } else {
+ // The specified device may be a valid device but the
+ // merged set device is different, so print both.
+ return errors::InvalidArgument(
+ "Could not satisfy explicit device specification '",
+ node->def().device(),
+ "' because the node was colocated with a group of nodes that "
+ "required incompatible device '",
+ DeviceNameUtils::ParsedNameToString(
+ members_[node_root].device_name),
+ "'");
+ }
+ }
+ } else {
+ // The device is completely unspecified, so enumerate the devices that
+ // support all of the nodes in the set.
+ if (device_set_->devices().empty()) {
+ return errors::Internal("No devices are registered");
+ }
+ devices = FilterSupportedDevices(
+ device_set_->devices(), members_[node_root].supported_device_types);
+
+ if (devices.empty()) {
+ return errors::InvalidArgument(
+ "Node had no OpKernel registered to support this operation: ",
+ "Operation was ", node->type_string(), " and inputs were ",
+ DataTypeVectorString(node->input_types()));
+ }
+ }
+
+      // Returns the first device in the sorted devices list so that we
+      // always choose the same device.
+ members_[node_root].assigned_device = devices[0];
+ }
+ node->set_assigned_device_name(members_[node_root].assigned_device->name());
+
+ // Log placement if log_device_placement is set.
+ if (options_ && options_->config.log_device_placement()) {
+ printf("%s: %s\n", node->name().c_str(),
+ node->assigned_device_name().c_str());
+ LOG(INFO) << node->name() << ": " << node->assigned_device_name();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // Represents a node in the disjoint node set forest, and the
+ // accumulated constraints on the device used by that node.
+ struct Member {
+ Member() = default;
+ // The id of the node that is the parent of this one, or its own
+    // id if it is a root. parent < 0 indicates that this member is invalid.
+ int parent = -1;
+ // A proxy for the depth of the tree that is used to prefer
+ // connecting smaller trees to larger trees when merging disjoint
+ // sets.
+ int rank = 0;
+ // The intersection of all device types supported by this node,
+ // and those of all of its children, in priority order
+ // of the preferred device.
+ DeviceTypeVector supported_device_types;
+ // The merged form of the device requested for this node, with
+ // those of all of its children.
+ DeviceNameUtils::ParsedName device_name;
+ // If this node is a root, stores the Device to which this node
+ // and all of its children have been assigned, or nullptr if this
+    // has not yet been computed by AssignDevice().
+ Device* assigned_device = nullptr;
+ };
+
+ Status InitializeMember(const Node& node, Member* member) {
+ const int id = node.id();
+ if (id < 0) {
+      return errors::InvalidArgument("Node id was negative: ", id);
+ }
+ member->parent = id;
+ TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+ device_types_, node.def(), &member->supported_device_types));
+
+ if (!node.assigned_device_name().empty()) {
+ // This node has already been assigned to a device, so we
+ // respect this placement, after sanity-checking it. The
+ // device_name and supported_device_types for this node reflect
+ // the assigned device, so any nodes colocated with this node
+ // will be assigned to the same device (assuming this is
+ // possible).
+ // NOTE: Since any assignment must have been performed by
+ // the TensorFlow runtime, we consider errors in this branch to
+ // be INTERNAL.
+ if (!DeviceNameUtils::ParseFullName(node.assigned_device_name(),
+ &member->device_name)) {
+ return errors::Internal("Malformed assigned device '",
+ node.assigned_device_name(), "'");
+ }
+ std::vector<Device*> devices;
+ const Device* assigned_device =
+ device_set_->FindDeviceByName(node.assigned_device_name());
+ if (assigned_device == nullptr) {
+ return errors::Internal("Assigned device '",
+ node.assigned_device_name(),
+ "' does not match any device");
+ }
+
+ for (DeviceType d : member->supported_device_types) {
+ if (DeviceType(assigned_device->attributes().device_type()) == d) {
+ return Status::OK();
+ }
+ }
+
+ return errors::Internal("Assigned device '", node.assigned_device_name(),
+ "' does not have registered OpKernel support "
+ "for ",
+ node.def().op());
+ } else {
+ // This node has not yet been assigned to a device, so we
+ // calculate any constraints due to the set of registered
+ // kernels and any (partial) user-provided device specification
+ // in the NodeDef.
+
+ // If no kernels are registered for this op type, fail with an error.
+ if (member->supported_device_types.empty()) {
+ return errors::InvalidArgument(
+ "No OpKernel was registered to support "
+ "Op '",
+ node.def().op(), "' with these attrs");
+ }
+
+ // If the NodeDef contains a device that is *not* a colocated node name
+ // (i.e. it does not begin with '@') then we interpret it as a (partial)
+ // device specification.
+ string colocated_node_name;
+ if (!node.def().device().empty() && !HasColocatedNodeName(node)) {
+ // The user has specified a device in the NodeDef, try to find a
+ // valid device matching their specification in the set of
+ // devices.
+ // NOTE: The full name may specify a device that is not in
+ // n.supported_device_types(), but we check that in AssignDevice().
+ if (!DeviceNameUtils::ParseFullName(node.def().device(),
+ &member->device_name)) {
+ return errors::InvalidArgument("Malformed device specification '",
+ node.def().device(), "'");
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ // Updates target to contain the intersection of the device types in
+ // "target" and "other".
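+  // For example, if *target is {GPU, CPU} (in priority order) and other is
+  // {CPU}, then after the call *target is {CPU}.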
+ static void MergeSupportedDevices(DeviceTypeVector* target,
+ const DeviceTypeVector& other) {
+ DeviceTypeVector temp = *target;
+ target->clear();
+
+ // Iterate in priority order.
+ for (DeviceType device_type : temp) {
+ bool found = false;
+ for (DeviceType other_device_type : other) {
+ if (device_type == other_device_type) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ target->push_back(device_type);
+ }
+ }
+ }
+
+ // Returns the root node of the disjoint tree to which the node with the
+ // given id is connected.
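+  // For example, with parent pointers 3 -> 2 -> 1 -> 1, FindRoot(3) returns
+  // 1 and, via path compression, repoints both 3 and 2 directly at 1.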
+ int FindRoot(int node_id) {
+ DCHECK_GE(members_[node_id].parent, 0);
+ if (members_[node_id].parent != node_id) {
+ // NOTE: Compress paths from node_id to its root, so that future
+ // calls to FindRoot and ColocateNodes are more efficient.
+ members_[node_id].parent = FindRoot(members_[node_id].parent);
+ }
+ return members_[node_id].parent;
+ }
+
+ std::vector<Member> members_;
+ const DeviceSet* device_set_; // Not owned.
+ const std::vector<DeviceType> device_types_;
+  const SessionOptions* options_;  // Not owned.
+};
+
+} // namespace
+
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map,
+ const SessionOptions* options)
+ : graph_(graph),
+ devices_(devices),
+ name_to_id_map_(name_to_id_map),
+ options_(options) {}
+
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map)
+ : graph_(graph), devices_(devices), name_to_id_map_(name_to_id_map) {
+ options_ = nullptr;
+}
+
+SimplePlacer::~SimplePlacer() {}
+
+Status SimplePlacer::Run() {
+ if (devices_->devices().empty()) {
+ return errors::FailedPrecondition("No devices are registered");
+ }
+
+ ColocationGraph colocation_graph(graph_, devices_, options_);
+ Status status;
+
+ // 1. First add all of the nodes. Note that steps (1) and (2)
+  // require two passes over the nodes because the graph (and hence
+ // the constraints) may not be acyclic.
+ for (Node* node : graph_->nodes()) {
+ // Skip the source and sink nodes.
+ if (!node->IsOp()) {
+ continue;
+ }
+ status = colocation_graph.AddNode(*node);
+ if (!status.ok()) return AttachDef(status, node->def());
+ }
+
+ // 2. Enumerate the constraint edges, and use them to update the disjoint
+ // node set.
+ for (Node* node : graph_->nodes()) {
+ if (!node->IsOp()) {
+ continue;
+ }
+
+ // 2(a). If node n specifies a colocation constraint as its device name,
+ // add an edge from the colocated node to n.
+ if (HasColocatedNodeName(*node)) {
+ string colocated_node_name;
+ status = ParseColocatedNodeName(*node, &colocated_node_name);
+ if (!status.ok()) {
+ return AttachDef(status, node->def());
+ }
+ Node* colocated_node;
+ status = GetNodeByName(colocated_node_name, &colocated_node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Colocated node named in device '",
+ colocated_node_name, "' does not exist"),
+ node->def());
+ }
+ status = colocation_graph.ColocateNodes(*colocated_node, *node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument(
+ "Cannot satisfy colocation constraint named in device '",
+ colocated_node_name, "': ", status.error_message()),
+ node->def());
+ }
+ }
+
+ // 2(b). If `node` has an input edge with reference type, add an
+ // edge from the source of that edge to `node`.
+ for (const auto& edge : node->in_edges()) {
+ if (!edge->IsControlEdge() &&
+ IsRefType(node->input_type(edge->dst_input()))) {
+ status = colocation_graph.ColocateNodes(*edge->src(), *node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Cannot satisfy colocation constraint "
+ "implied by reference connection: ",
+ status.error_message()),
+ node->def());
+ }
+ }
+ }
+ }
+
+ // 3. For each node, assign a device based on the constraints in the
+ // disjoint node set.
+ for (Node* node : graph_->nodes()) {
+ // Skip the source and sink nodes.
+ if (!node->IsOp()) {
+ continue;
+ }
+    // Skip nodes that already have an assigned device name.
+ if (!node->assigned_device_name().empty()) {
+ continue;
+ }
+
+ status = colocation_graph.AssignDevice(node);
+ if (!status.ok()) {
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device to node '",
+ node->name(), "': ", status.error_message()),
+ node->def());
+ }
+ }
+ return Status::OK();
+}
+
+Status SimplePlacer::GetNodeByName(const string& name, Node** out_node) const {
+ NodeNameToIdMap::const_iterator iter = name_to_id_map_->find(name);
+ if (iter != name_to_id_map_->end()) {
+ *out_node = graph_->FindNodeId(iter->second);
+ if (*out_node) {
+ return Status::OK();
+ }
+ }
+ return errors::NotFound(name);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h
new file mode 100644
index 0000000000..4b3df50c72
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer.h
@@ -0,0 +1,81 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
+#define TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// A placement algorithm that assigns the nodes of the given Graph to
+// devices in the given DeviceSet, respecting the following constraints:
+//
+// 1. Existing device assignments remain unchanged.
+// 2. Requested (partial or complete) device specifications given in the
+//    NodeDefs are granted.
+// 3. Nodes connected by edges of a reference type are colocated on
+// the same device.
+// 4. Given nodes "A" and "B", if node "B" has the device specification
+// "@A", nodes "A" and "B" will be colocated on the same device.
+//
+// The implementation builds a constraint graph with the same set of
+// nodes, and edges that represent colocation constraints between
+// nodes. Each connected component in the resulting constraint graph
+// is then assigned to a single device.
+//
+// TODO(mrry): "Soft" constraints, such as "place node 'x' as close as
+// possible to node 'y' while respecting the other constraints"?
+// TODO(mrry): Create a common interface for this and the other
+// placement algorithms so that they may be injected into the graph
+// builder.
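+//
+// Usage sketch (assumes an existing Graph "graph", DeviceSet "devices", and
+// SessionOptions "options"; the names are illustrative):
+//
+//   SimplePlacer::NodeNameToIdMap name_to_id_map;
+//   for (Node* node : graph.nodes()) {
+//     name_to_id_map[node->name()] = node->id();
+//   }
+//   SimplePlacer placer(&graph, &devices, &name_to_id_map, &options);
+//   TF_RETURN_IF_ERROR(placer.Run());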
+class SimplePlacer {
+ public:
+ // A map from graph node names to numerical IDs (in a Graph object).
+ typedef std::unordered_map<string, int> NodeNameToIdMap;
+
+ // Creates an instance of the SimplePlacer algorithm for the given
+ // Graph "graph" (nodes in which may or may not be assigned) on the
+ // given DeviceSet "devices". The "name_to_id_map" maps the names of
+  // nodes in "graph" to their numerical ID.
+ //
+ // REQUIRES: for all mappings (k, v) in "name_to_id_map",
+ // graph.FindNodeId(v)->name() == k.
+ //
+ // The "graph", "devices", and "name_to_id_map" pointer arguments
+ // are borrowed by this SimplePlacer, and must outlive it.
+ SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map,
+ const SessionOptions* options);
+
+ SimplePlacer(Graph* graph, const DeviceSet* devices,
+ const NodeNameToIdMap* name_to_id_map);
+
+ ~SimplePlacer();
+
+ // Assigns each node in this SimplePlacer's graph to a device in its
+ // set of devices.
+ //
+ // This method is not thread-safe.
+ // Run() may be invoked at most once.
+ Status Run();
+
+ private:
+ Status GetNodeByName(const string& name, Node** out_node) const;
+
+ Graph* const graph_; // Not owned.
+ const DeviceSet* const devices_; // Not owned.
+ const NodeNameToIdMap* const name_to_id_map_; // Not owned.
+ const SessionOptions* options_; // Not owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_SIMPLE_PLACER_H_
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
new file mode 100644
index 0000000000..3139962d7e
--- /dev/null
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -0,0 +1,863 @@
+#include "tensorflow/core/common_runtime/simple_placer.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace {
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Op, kernel, and device registrations to set up the environment.
+//
+// The SimplePlacer uses information about the op (input types),
+// kernel (device constraints), and available devices to make
+// placement decisions. To avoid depending on the full runtime, we
+// define dummy implementations of these, and register them with the
+// runtime.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// A dummy OpKernel that is used to register ops on different devices.
+class DummyOp : public OpKernel {
+ public:
+ explicit DummyOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+// A fake device that has specific device attributes, used to simulate
+// the presence of a CPU or a GPU (without depending on that part of
+// the runtime).
+class FakeDevice : public Device {
+ private:
+ explicit FakeDevice(const DeviceAttributes& device_attributes)
+ : Device(nullptr, device_attributes, nullptr) {}
+
+ public:
+ Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
+
+ static std::unique_ptr<Device> MakeCPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType(DEVICE_CPU).type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+
+ static std::unique_ptr<Device> MakeGPU(const string& name) {
+ DeviceAttributes device_attributes;
+ device_attributes.set_name(name);
+ device_attributes.set_device_type(DeviceType(DEVICE_GPU).type());
+ return std::unique_ptr<Device>(new FakeDevice(device_attributes));
+ }
+};
+
+// Register the following ops so they can be added to a Graph, and
+// kernels so that they can be placed on particular device types.
+REGISTER_OP("TestVariable").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("TestVariable").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestVariable").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("VariableCPU").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("VariableCPU").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("VariableGPU").Output("o: Ref(float)");
+REGISTER_KERNEL_BUILDER(Name("VariableGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("VariableNoKernels").Output("o: Ref(float)");
+
+REGISTER_OP("TestAdd").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("TestAdd").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestAdd").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("TestRelu").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestRelu").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("ReluGPU").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("ReluGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestAssign").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("TestAssign").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestAssign").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("AssignCPU").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("AssignCPU").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("AssignGPU").Input("i: Ref(float)").Input("v: float");
+REGISTER_KERNEL_BUILDER(Name("AssignGPU").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestInput").Device(DEVICE_CPU), DummyOp);
+
+REGISTER_OP("TestDevice").Output("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestDevice").Device(DEVICE_GPU), DummyOp);
+
+REGISTER_OP("TestDeviceEnforce").Input("a: Ref(float)").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device(DEVICE_CPU), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device(DEVICE_GPU), DummyOp);
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// A SimplePlacerTest method has three phases:
+//
+// 1. Build a TensorFlow graph, with no (or partial) device assignments.
+// 2. Attempt to compute a placement using the SimplePlacer.
+// 3. EITHER: test that the constraints implied by the graph are respected;
+// or that an appropriate error was reported.
+//
+////////////////////////////////////////////////////////////////////////////////
+class SimplePlacerTest : public ::testing::Test {
+ protected:
+ SimplePlacerTest() {
+ RequireDefaultOps();
+ // Build a set of 10 GPU and 10 CPU devices.
+ // NOTE: this->local_devices_ owns the device objects;
+ // this->devices_ contains borrowed pointers to the device
+ // objects.
+ for (int i = 0; i < 10; ++i) {
+ local_devices_.emplace_back(FakeDevice::MakeCPU(
+ strings::StrCat("/job:a/replica:0/task:0/cpu:", i)));
+ devices_.AddDevice(local_devices_.back().get());
+ // Insert the GPUs in reverse order.
+ local_devices_.emplace_back(FakeDevice::MakeGPU(
+ strings::StrCat("/job:a/replica:0/task:0/gpu:", 9 - i)));
+ devices_.AddDevice(local_devices_.back().get());
+ }
+ }
+
+ // Builds the given graph, and (if successful) indexes the node
+ // names for use in placement, and later lookup.
+ Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) {
+ TF_RETURN_IF_ERROR(builder.ToGraph(out_graph));
+ nodes_by_name_.clear();
+ for (Node* node : out_graph->nodes()) {
+ nodes_by_name_[node->name()] = node->id();
+ }
+ return Status::OK();
+ }
+
+ // Invokes the SimplePlacer on "graph". If no DeviceSet is specified, the
+ // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
+ //
+ // REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
+ Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) {
+ SimplePlacer placer(graph, devices, &nodes_by_name_, options);
+ return placer.Run();
+ }
+
+ Status Place(Graph* graph, DeviceSet* devices) {
+ return Place(graph, devices, nullptr);
+ }
+
+ Status Place(Graph* graph, SessionOptions* options) {
+ return Place(graph, &devices_, options);
+ }
+
+ Status Place(Graph* graph) { return Place(graph, &devices_, nullptr); }
+
+ // Returns the node in "graph" with the given name.
+ //
+ // REQUIRES: "graph" was produced by the most recent call to BuildGraph.
+ Node* GetNodeByName(const Graph& graph, const string& name) {
+ const auto search = nodes_by_name_.find(name);
+ CHECK(search != nodes_by_name_.end()) << "Unknown node name: " << name;
+ return graph.FindNodeId(search->second);
+ }
+
+ protected:
+ std::vector<std::unique_ptr<Device>> local_devices_;
+ DeviceSet devices_;
+ SimplePlacer::NodeNameToIdMap nodes_by_name_;
+
+ Status ReferenceTestHelper(const string& variable_op_type,
+ const string& assign_op_type,
+ DeviceType expected_device_type);
+};
+
+#define EXPECT_COLOCATED(g, name_a, name_b) \
+ do { \
+ Graph& g_ = (g); \
+ EXPECT_EQ(GetNodeByName(g_, (name_a))->assigned_device_name(), \
+ GetNodeByName(g_, (name_b))->assigned_device_name()); \
+ } while (0)
+
+#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \
+ EXPECT_EQ(DeviceType(expected_device_type).type(), \
+ devices_.FindDeviceByName( \
+ GetNodeByName((g), (name))->assigned_device_name()) \
+ ->attributes() \
+ .device_type())
+
+#define EXPECT_DEVICE_CONTAINS(g, name, device_substr) \
+ EXPECT_TRUE(StringPiece(GetNodeByName((g), (name))->assigned_device_name()) \
+ .contains(device_substr))
+
+// Test that a graph with no constraints will successfully assign nodes to the
+// "best available" device (i.e. prefer GPU over CPU).
+TEST_F(SimplePlacerTest, TestNoConstraints) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ ops::UnaryOp("TestRelu", ops::NodeOut(input, 0), b.opts().WithName("n1"));
+ ops::UnaryOp("TestRelu", ops::NodeOut(input, 1), b.opts().WithName("n2"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "n1", DEVICE_GPU);
+ EXPECT_DEVICE_TYPE(g, "n2", DEVICE_GPU);
+}
+
+// Test that a graph with device type and reference constraints on
+// some of the ops will successfully assign nodes to the constrained
+// device, and colocate nodes with reference connections.
+TEST_F(SimplePlacerTest, TestDeviceTypeConstraints) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+ ops::BinaryOp("AssignCPU", var_cpu, input, b.opts().WithName("assign_cpu"));
+ Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
+ ops::BinaryOp("AssignGPU", var_gpu, input, b.opts().WithName("assign_gpu"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "var_cpu", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "assign_cpu", DEVICE_CPU);
+ EXPECT_COLOCATED(g, "var_cpu", "assign_cpu");
+ EXPECT_DEVICE_TYPE(g, "var_gpu", DEVICE_GPU);
+ EXPECT_DEVICE_TYPE(g, "assign_gpu", DEVICE_GPU);
+ EXPECT_COLOCATED(g, "var_gpu", "assign_gpu");
+}
+
+// Test that a graph with partial device specifications on the ops
+// will successfully be placed on devices matching those specifications.
+TEST_F(SimplePlacerTest, TestPartialSpec) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:a"));
+ ops::SourceOp("TestVariable",
+ b.opts().WithName("var").WithDevice("/job:a"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_CONTAINS(g, "in", "/job:a");
+ EXPECT_DEVICE_TYPE(g, "var", DEVICE_GPU);
+ EXPECT_DEVICE_CONTAINS(g, "var", "/job:a");
+}
+
+// Test that a node with an assigned device is not relocated.
+TEST_F(SimplePlacerTest, TestAssignedDevicePreserved) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/cpu:7");
+
+ EXPECT_OK(Place(&g));
+ EXPECT_EQ("/job:a/replica:0/task:0/cpu:7",
+ GetNodeByName(g, "in")->assigned_device_name());
+}
+
+// Test that a graph with partial device specifications for CPU-only ops
+// will be relocated to CPU.
+TEST_F(SimplePlacerTest, TestPartialSpecGpuToCpu) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/gpu:0"));
+ ops::SourceOp("TestVariable",
+ b.opts().WithName("var").WithDevice("/gpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+ EXPECT_DEVICE_TYPE(g, "in", DEVICE_CPU);
+ EXPECT_DEVICE_CONTAINS(g, "in", "/cpu");
+ EXPECT_DEVICE_TYPE(g, "var", DEVICE_GPU);
+ EXPECT_DEVICE_CONTAINS(g, "var", "/gpu:0");
+}
+
+// Test that placement fails for a node with an assigned GPU device for
+// which no OpKernel is registered.
+TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/gpu:0");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:a/replica:0/task:0/gpu:0' "
+ "does not have registered OpKernel support for TestInput"));
+}
+
+// Test that graphs with reference connections are correctly placed.
+
+// Build a graph containing a Variable op of "variable_op_type" and an
+// Assign op of "assign_op_type", and expect all of the ops to be
+// placed on a device of type "expected_device_type".
+Status SimplePlacerTest::ReferenceTestHelper(const string& variable_op_type,
+ const string& assign_op_type,
+ DeviceType expected_device_type) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ // Build ten variable-and-assignment pairs.
+ for (int i = 0; i < 10; ++i) {
+ Node* var = ops::SourceOp(variable_op_type,
+ b.opts().WithName(strings::StrCat("var_", i)));
+ ops::BinaryOp(assign_op_type, var, input,
+ b.opts().WithName(strings::StrCat("assign_", i)));
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_RETURN_IF_ERROR(Place(&g));
+
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("assign_", i));
+ EXPECT_DEVICE_TYPE(g, strings::StrCat("var_", i), expected_device_type);
+ EXPECT_DEVICE_TYPE(g, strings::StrCat("assign_", i), expected_device_type);
+ }
+
+ return Status::OK();
+}
+
+// Test all 3^2 = 9 combinations of Variable and Assignment op types
+// (unconstrained, CPU-only, and GPU-only).
+TEST_F(SimplePlacerTest, TestReferenceConnection) {
+ Status s;
+ EXPECT_OK(ReferenceTestHelper("TestVariable", "TestAssign", DEVICE_GPU));
+ EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignCPU", DEVICE_CPU));
+ EXPECT_OK(ReferenceTestHelper("TestVariable", "AssignGPU", DEVICE_GPU));
+ EXPECT_OK(ReferenceTestHelper("VariableCPU", "TestAssign", DEVICE_CPU));
+ EXPECT_OK(ReferenceTestHelper("VariableCPU", "AssignCPU", DEVICE_CPU));
+ {
+ Status s = ReferenceTestHelper("VariableCPU", "AssignGPU", DEVICE_CPU);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("no device type supports both of those nodes"));
+ }
+ EXPECT_OK(ReferenceTestHelper("VariableGPU", "TestAssign", DEVICE_GPU));
+ {
+ Status s = ReferenceTestHelper("VariableGPU", "AssignCPU", DEVICE_CPU);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("no device type supports both of those nodes"));
+ }
+ EXPECT_OK(ReferenceTestHelper("VariableGPU", "AssignGPU", DEVICE_GPU));
+}
+
+// Test the handling of '@node_name' colocation constraints, when
+// these are arranged in multiple chains.
+TEST_F(SimplePlacerTest, TestColocatedChain) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* last_node = input;
+ for (int i = 0; i < 100; ++i) {
+ if (i % 10 == 0) {
+ // Every ten nodes, start a new chain.
+ last_node = ops::UnaryOp("TestRelu", last_node,
+ b.opts().WithName(strings::StrCat("n_", i)));
+ } else {
+ // Chain each successive node to the previous one.
+ last_node =
+ ops::UnaryOp("TestRelu", last_node,
+ b.opts()
+ .WithName(strings::StrCat("n_", i))
+ .WithDevice(strings::StrCat("@n_", i - 1)));
+ }
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ for (int i = 0; i < 100; ++i) {
+ if (i % 10 != 0) {
+      EXPECT_COLOCATED(g, strings::StrCat("n_", i - (i % 10)),
+ strings::StrCat("n_", i));
+ }
+ }
+}
+
+// Test the handling of '@node_name' colocation constraints, when the
+// chains are shuffled.
+TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* last_node = input;
+ for (int i = 0; i < 10; ++i) {
+ // Start ten chains.
+ last_node = ops::UnaryOp("TestRelu", last_node,
+ b.opts().WithName(strings::StrCat("n_", i)));
+ }
+ for (int i = 10; i < 100; ++i) {
+ // Add each node to the (i % 10)^th chain.
+ last_node = ops::UnaryOp("TestRelu", last_node,
+ b.opts()
+ .WithName(strings::StrCat("n_", i))
+ .WithDevice(strings::StrCat("@n_", i % 10)));
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ for (int i = 10; i < 100; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("n_", i % 10),
+ strings::StrCat("n_", i));
+ }
+}
+
+TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ for (int i = 0; i < 10; ++i) {
+ // Declare ten variable and assignment pairs.
+ Node* var = ops::SourceOp("TestVariable",
+ b.opts().WithName(strings::StrCat("var_", i)));
+ ops::BinaryOp("TestAssign", var, input,
+ b.opts().WithName(strings::StrCat("assign_", i)));
+ }
+ for (int i = 10; i < 100; ++i) {
+ // Create a variable colocated with some existing variable, and
+ // an assignment colocated with a possibly-different variable.
+ Node* var = ops::SourceOp(
+ "TestVariable", b.opts()
+ .WithName(strings::StrCat("var_", i))
+ .WithDevice(strings::StrCat("@var_", i % 6)));
+ ops::BinaryOp("TestAssign", var, input,
+ b.opts()
+ .WithName(strings::StrCat("assign_", i))
+ .WithDevice(strings::StrCat("@assign_", i % 3)));
+ }
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ EXPECT_OK(Place(&g));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("assign_", i));
+ }
+ for (int i = 10; i < 100; ++i) {
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("assign_", i));
+ EXPECT_COLOCATED(g, strings::StrCat("var_", i),
+ strings::StrCat("var_", i % 6));
+ EXPECT_COLOCATED(g, strings::StrCat("assign_", i),
+ strings::StrCat("assign_", i % 3));
+ }
+}
+
+// Test that placement fails when no devices are registered.
+TEST_F(SimplePlacerTest, TestEmptyDeviceSet) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet empty;
+
+ Status s = Place(&g, &empty);
+ EXPECT_TRUE(
+ StringPiece(s.error_message()).contains("No devices are registered"));
+}
+
+// Test that placement fails when the requested device forces an
+// indirect constraint to be violated.
+TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* in = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ ops::BinaryOp("TestAssign", var, in,
+ b.opts().WithName("assign").WithDevice("/job:b/task:1"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet heterogeneous;
+ std::unique_ptr<Device> gpu(
+ FakeDevice::MakeGPU("/job:b/replica:0/task:0/gpu:0"));
+ heterogeneous.AddDevice(gpu.get());
+ std::unique_ptr<Device> cpu(
+ FakeDevice::MakeCPU("/job:b/replica:0/task:1/cpu:0"));
+ heterogeneous.AddDevice(cpu.get());
+ Status s = Place(&g, &heterogeneous);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("colocated with a group of nodes that required "
+ "incompatible device"));
+}
+
+// Test that placement fails when an unknown device is requested.
+TEST_F(SimplePlacerTest, TestUnknownDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/job:foo'"));
+}
+
+// Test that placement fails when the combination of partial
+// constraints leads to an unknown device.
+TEST_F(SimplePlacerTest, TestUnknownMergedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/job:foo"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/job:foo'"));
+}
+
+// Test that placement fails when the previously-assigned device for a
+// node is unknown.
+TEST_F(SimplePlacerTest, TestUnknownAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/job:foo");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:foo' does not match any device"));
+}
+
+// Test that placement fails when an op with no registered kernels is
+// requested.
+TEST_F(SimplePlacerTest, TestNoKernelsRegistered) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "No OpKernel was registered to support Op 'VariableNoKernels'"));
+}
+
+// Test that placement fails when a kernel is registered but no known
+// device supports it.
+TEST_F(SimplePlacerTest, TestNoDevicesRegistered) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ DeviceSet cpu_only;
+ std::unique_ptr<Device> cpu(
+ FakeDevice::MakeCPU("/job:a/replica:0/task:0/cpu:0"));
+ cpu_only.AddDevice(cpu.get());
+
+ Status s = Place(&g, &cpu_only);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("No OpKernel was registered to support "
+ "Op 'VariableGPU'"));
+}
+
+// Test that placement fails when a requested device is malformed.
+TEST_F(SimplePlacerTest, TestMalformedDeviceSpecification) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("/foo:bar"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Malformed device specification '/foo:bar'"));
+}
+
+// Test that placement fails when a previously-assigned device is malformed.
+TEST_F(SimplePlacerTest, TestMalformedAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/foo:bar");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Malformed assigned device '/foo:bar'"));
+}
+
+// Test that placement fails when a device was previously assigned to
+// a node, but it does not uniquely identify a particular device.
+TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ GetNodeByName(g, "in")->set_assigned_device_name("/job:a");
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INTERNAL, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Assigned device '/job:a' does not match any device"));
+}
+
+// Test that placement fails when a node requests colocation with another
+// node that does not exist.
+TEST_F(SimplePlacerTest, TestUnknownColocatedNode) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@foo"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message()).contains("'foo' does not exist"));
+}
+
+// Test that placement fails when a node requests colocation with a
+// malformed node name.
+TEST_F(SimplePlacerTest, TestMalformedColocatedNode) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("node named in device '' does not exist"));
+}
+
+// Test that an op requested to be placed on a non-existent device will be
+// relocated to an existing device of the same type if allow_soft_placement
+// is set.
+TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestDevice", b.opts().WithName("in").WithDevice("/gpu:11"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+ EXPECT_DEVICE_CONTAINS(g, "in", "/gpu:0");
+}
+
+// Test that an op requested to be placed on a non-existent device will fail
+// if allow_soft_placement is not set.
+TEST_F(SimplePlacerTest, TestNonexistentGpuNoAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("TestDevice", b.opts().WithName("in").WithDevice("/gpu:11"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/gpu:11'"));
+}
+
+// Test that placement fails when a node requests an explicit device that is
+// not supported by the registered kernels if allow_soft_placement is not set.
+TEST_F(SimplePlacerTest, TestUnsupportedDeviceNoAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var").WithDevice("/cpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains(
+ "Could not satisfy explicit device specification '/cpu:0'"));
+}
+
+TEST_F(SimplePlacerTest, TestUnsupportedDeviceAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ ops::SourceOp("VariableGPU", b.opts().WithName("var").WithDevice("/cpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+}
+
+// Test that, with allow_soft_placement set, nodes joined by reference
+// connections are colocated on a device that supports their kernels, even
+// when that overrides an explicitly requested device.
+TEST_F(SimplePlacerTest, TestDeviceTypeConstraintsAllowSoftPlacement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ // var_gpu has ref output and runs on GPU.
+    // force_gpu takes var_gpu and requests CPU.
+ // Verify that both are placed on GPU.
+ Node* var_gpu = ops::SourceOp("VariableGPU", b.opts().WithName("var_gpu"));
+ ops::UnaryOp("TestDeviceEnforce", var_gpu,
+ b.opts().WithName("force_gpu").WithDevice("/cpu:0"));
+ // var_cpu has ref output and runs on CPU.
+    // force_cpu takes var_cpu and requests GPU.
+ // Verify that both are placed on CPU.
+ Node* var_cpu = ops::SourceOp("VariableCPU", b.opts().WithName("var_cpu"));
+ ops::UnaryOp("TestDeviceEnforce", var_cpu,
+ b.opts().WithName("force_cpu").WithDevice("/gpu:0"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(true);
+ EXPECT_OK(Place(&g, &options));
+ EXPECT_DEVICE_TYPE(g, "var_gpu", DEVICE_GPU);
+ EXPECT_DEVICE_TYPE(g, "force_gpu", DEVICE_GPU);
+ EXPECT_COLOCATED(g, "var_gpu", "force_gpu");
+ EXPECT_DEVICE_TYPE(g, "var_cpu", DEVICE_CPU);
+ EXPECT_DEVICE_TYPE(g, "force_cpu", DEVICE_CPU);
+ EXPECT_COLOCATED(g, "var_cpu", "force_cpu");
+}
+
+// Test that placement fails when two nodes have a reference connection
+// constraint, and each node requires a mutually incompatible device.
+TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* var = ops::SourceOp("VariableGPU", b.opts().WithName("var"));
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ ops::BinaryOp("AssignCPU", var, input, b.opts().WithName("assign"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Cannot colocate nodes 'var' and 'assign'"));
+}
+
+// Test that placement fails when two nodes have an explicit
+// colocation constraint, and each node requires a mutually
+// incompatible device.
+TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithColocatedNodes) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput",
+ b.opts().WithName("in").WithDevice("/gpu:0"));
+ Node* relu_1 = ops::UnaryOp("TestRelu", input,
+ b.opts().WithName("relu_1").WithDevice("@in"));
+ ops::UnaryOp("ReluGPU", relu_1,
+ b.opts().WithName("relu_2").WithDevice("@relu_1"));
+ EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(StringPiece(s.error_message())
+ .contains("Cannot colocate nodes 'relu_1' and 'relu_2'"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
new file mode 100644
index 0000000000..4806e69c67
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -0,0 +1,55 @@
+#include "tensorflow/core/common_runtime/threadpool_device.h"
+
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency,
+ Allocator* allocator)
+ : LocalDevice(options, Device::BuildDeviceAttributes(
+ name, DEVICE_CPU, memory_limit, bus_adjacency),
+ allocator),
+ allocator_(allocator) {}
+
+ThreadPoolDevice::~ThreadPoolDevice() {}
+
+void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ if (port::Tracing::IsActive()) {
+ // TODO(pbar) We really need a useful identifier of the graph node.
+ const uint64 id = Hash64(op_kernel->name());
+ port::Tracing::ScopedActivity region(port::Tracing::EventCategory::kCompute,
+ id);
+ op_kernel->Compute(context);
+ } else {
+ op_kernel->Compute(context);
+ }
+}
+
+Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
+ return allocator_;
+}
+
+Status ThreadPoolDevice::MakeTensorFromProto(
+ const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+ *tensor = parsed;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h
new file mode 100644
index 0000000000..5b0347231f
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device.h
@@ -0,0 +1,31 @@
+#ifndef TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+#define TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+
+namespace tensorflow {
+
+// CPU device implementation.
+class ThreadPoolDevice : public LocalDevice {
+ public:
+ ThreadPoolDevice(const SessionOptions& options, const string& name,
+ Bytes memory_limit, BusAdjacency bus_adjacency,
+ Allocator* allocator);
+ ~ThreadPoolDevice() override;
+
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
+ Allocator* GetAllocator(AllocatorAttributes attr) override;
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override;
+
+ Status Sync() override { return Status::OK(); }
+
+ private:
+ Allocator* allocator_; // Not owned
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_THREADPOOL_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc
new file mode 100644
index 0000000000..ee6319abad
--- /dev/null
+++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc
@@ -0,0 +1,31 @@
+// Register a factory that provides CPU devices.
+#include "tensorflow/core/common_runtime/threadpool_device.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// TODO(zhifengc/tucker): Figure out the bytes of available RAM.
+class ThreadPoolDeviceFactory : public DeviceFactory {
+ public:
+ void CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override {
+ // TODO(zhifengc/tucker): Figure out the number of available CPUs
+ // and/or NUMA configuration.
+ int n = 1;
+ auto iter = options.config.device_count().find("CPU");
+ if (iter != options.config.device_count().end()) {
+ n = iter->second;
+ }
+ for (int i = 0; i < n; i++) {
+ string name = strings::StrCat(name_prefix, "/cpu:", i);
+ devices->push_back(new ThreadPoolDevice(options, name, Bytes(256 << 20),
+ BUS_ANY, cpu_allocator()));
+ }
+ }
+};
+REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory);
+
+} // namespace tensorflow
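For illustration, a minimal sketch of how the "CPU" entry in ConfigProto.device_count drives the factory above. In real use devices are created through the registered factory rather than by direct instantiation, and the name prefix "/job:localhost/replica:0/task:0" is an assumption for this example only:

  SessionOptions options;
  (*options.config.mutable_device_count())["CPU"] = 2;  // ask for two CPU devices
  std::vector<Device*> devices;
  ThreadPoolDeviceFactory factory;
  factory.CreateDevices(options, "/job:localhost/replica:0/task:0", &devices);
  // devices now holds ".../cpu:0" and ".../cpu:1", each backed by cpu_allocator()
  // and the fixed 256MB memory limit set in the loop above.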
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
new file mode 100644
index 0000000000..194d1e7c24
--- /dev/null
+++ b/tensorflow/core/example/example.proto
@@ -0,0 +1,95 @@
+// Protocol messages for describing input data Examples for machine learning
+// model training or inference.
+syntax = "proto3";
+
+import "tensorflow/core/example/feature.proto";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+// Example for a movie recommendation application:
+// features {
+// feature {
+// key: "age"
+// float_list {
+// value: 29.0
+// }
+// }
+// feature {
+// key: "movie"
+// bytes_list {
+// value: "The Shawshank Redemption"
+// value: "Fight Club"
+// }
+// }
+// feature {
+// key: "movie_ratings"
+// float_list {
+// value: 9.0
+// value: 9.7
+// }
+// }
+// feature {
+// key: "suggestion"
+// bytes_list {
+// value: "Inception"
+// }
+// }
+// # Note that this feature exists to be used as a label in training.
+// # E.g., if training a logistic regression model to predict purchase
+// # probability in our learning tool we would set the label feature to
+// # "suggestion_purchased".
+// feature {
+// key: "suggestion_purchased"
+// float_list {
+// value: 1.0
+// }
+// }
+// # Similar to "suggestion_purchased" above this feature exists to be used
+// # as a label in training.
+// # E.g., if training a linear regression model to predict purchase
+// # price in our learning tool we would set the label feature to
+// # "purchase_price".
+// feature {
+// key: "purchase_price"
+// float_list {
+// value: 9.99
+// }
+// }
+// }
+//
+// A conformant data set obeys the following conventions:
+// - If a Feature K exists in one example with data type T, it must be of
+// type T in all other examples when present. It may be omitted.
+// - The number of instances of Feature K list data may vary across examples,
+// depending on the requirements of the model.
+// - If a Feature K doesn't exist in an example, a K-specific default will be
+// used, if configured.
+// - If a Feature K exists in an example but contains no items, the intent
+// is considered to be an empty tensor and no default will be used.
+
+message Example {
+ Features features = 1;
+};
+
+// Example representing a ranking instance.
+message RankingExample {
+ Features context = 1;
+ repeated Features positive = 2;
+ repeated Features negative = 3;
+};
+
+// Example representing a sequence.
+// The context contains features which apply to the entire sequence.
+// Each element in the repeated "features" field represents an entry in the sequence.
+message SequenceExample {
+ Features context = 1;
+ repeated Features features = 2;
+};
+
+// Example representing a list of feature maps.
+// The context contains features which apply to all feature maps.
+message InferenceExample {
+ Features context = 1;
+ repeated Features features = 2;
+};
diff --git a/tensorflow/core/example/feature.proto b/tensorflow/core/example/feature.proto
new file mode 100644
index 0000000000..5ab77c2997
--- /dev/null
+++ b/tensorflow/core/example/feature.proto
@@ -0,0 +1,82 @@
+// Protocol messages for describing features for machine learning model
+// training or inference.
+//
+// There are three base Feature types:
+// - bytes
+// - float
+// - int64
+//
+// Base features are contained in Lists which may hold zero or more values.
+//
+// Features are organized into categories by name. The Features message
+// contains the mapping from name to Feature.
+//
+// Example Features for a movie recommendation application:
+// feature {
+// key: "age"
+// float_list {
+// value: 29.0
+// }
+// }
+// feature {
+// key: "movie"
+// bytes_list {
+// value: "The Shawshank Redemption"
+// value: "Fight Club"
+// }
+// }
+// feature {
+// key: "movie_ratings"
+// float_list {
+// value: 9.0
+// value: 9.7
+// }
+// }
+// feature {
+// key: "suggestion"
+// bytes_list {
+// value: "Inception"
+// }
+// }
+// feature {
+// key: "suggestion_purchased"
+// int64_list {
+// value: 1
+// }
+// }
+// feature {
+// key: "purchase_price"
+// float_list {
+// value: 9.99
+// }
+// }
+
+syntax = "proto3";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+message Feature {
+ // Each feature can be exactly one kind.
+ oneof kind {
+ BytesList bytes_list = 1;
+ FloatList float_list = 2;
+ Int64List int64_list = 3;
+ }
+};
+
+message Features {
+ // Map from feature name to feature.
+ map<string, Feature> feature = 1;
+};
+
+// Containers to hold repeated fundamental features.
+message BytesList {
+ repeated bytes value = 1;
+}
+message FloatList {
+ repeated float value = 1 [packed=true];
+}
+message Int64List {
+ repeated int64 value = 1 [packed=true];
+}
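A small C++ sketch, assuming the usual proto3 generated API for the messages above, showing how the movie-recommendation example from the comments would be built programmatically (the keys and values are taken from that comment):

  Example example;
  auto* feature_map = example.mutable_features()->mutable_feature();
  (*feature_map)["age"].mutable_float_list()->add_value(29.0);
  (*feature_map)["movie"].mutable_bytes_list()->add_value("The Shawshank Redemption");
  (*feature_map)["movie"].mutable_bytes_list()->add_value("Fight Club");
  (*feature_map)["suggestion_purchased"].mutable_float_list()->add_value(1.0);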
diff --git a/tensorflow/core/framework/allocation_description.proto b/tensorflow/core/framework/allocation_description.proto
new file mode 100644
index 0000000000..f6f4bc0126
--- /dev/null
+++ b/tensorflow/core/framework/allocation_description.proto
@@ -0,0 +1,15 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+message AllocationDescription {
+ // Total number of bytes requested
+ int64 requested_bytes = 1;
+
+ // Total number of bytes allocated if known
+ int64 allocated_bytes = 2;
+
+ // Name of the allocator used
+ string allocator_name = 3;
+};
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
new file mode 100644
index 0000000000..93f68dcccb
--- /dev/null
+++ b/tensorflow/core/framework/allocator.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+Allocator::~Allocator() {}
+
+class CPUAllocator : public Allocator {
+ public:
+ ~CPUAllocator() override {}
+
+ string Name() override { return "cpu"; }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ return port::aligned_malloc(num_bytes, alignment);
+ }
+
+ void DeallocateRaw(void* ptr) override { port::aligned_free(ptr); }
+};
+
+Allocator* cpu_allocator() {
+ static CPUAllocator* cpu_alloc = new CPUAllocator;
+ return cpu_alloc;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
new file mode 100644
index 0000000000..6f162a608c
--- /dev/null
+++ b/tensorflow/core/framework/allocator.h
@@ -0,0 +1,132 @@
+#ifndef TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+#define TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
+
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <limits>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Allocator is an abstract interface for allocating and deallocating
+// device memory.
+class Allocator {
+ public:
+ virtual ~Allocator();
+
+ // Return a string identifying this allocator
+ virtual string Name() = 0;
+
+ // Return an uninitialized block of memory that is "num_bytes" bytes
+ // in size. The returned pointer is guaranteed to be aligned to a
+ // multiple of "alignment" bytes.
+ // REQUIRES: "alignment" is a power of 2.
+ virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0;
+
+ // Deallocate a block of memory pointed to by "ptr"
+ // REQUIRES: "ptr" was previously returned by a call to AllocateRaw
+ virtual void DeallocateRaw(void* ptr) = 0;
+
+ // Convenience functions to do typed allocation. Note that these functions
+ // do not invoke C++ constructors or destructors. May return NULL if the
+ // tensor has too many elements to represent in a single allocation.
+ template <typename T>
+ T* Allocate(size_t num_elements) {
+ // TODO(jeff): Do we need to allow clients to pass in alignment
+ // requirements?
+
+ if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
+ return NULL;
+ }
+
+ void* p = AllocateRaw(32 /* align to 32 byte boundary */,
+ sizeof(T) * num_elements);
+ return reinterpret_cast<T*>(p);
+ }
+
+ template <typename T>
+ void Deallocate(T* ptr) {
+ DeallocateRaw(ptr);
+ }
+
+ // Returns true if this allocator tracks the sizes of allocations.
+ // RequestedSize and AllocatedSize must be overridden if
+ // TracksAllocationSizes is overridden to return true.
+ virtual bool TracksAllocationSizes() { return false; }
+
+ // Returns the user-requested size of the data allocated at
+ // 'ptr'. Note that the actual buffer allocated might be larger
+ // than requested, but this function returns the size requested by
+ // the user.
+ //
+ // REQUIRES: TracksAllocationSizes() is true.
+ //
+ // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+ // allocated by this allocator.
+ virtual size_t RequestedSize(void* ptr) {
+ CHECK(false) << "allocator doesn't track sizes";
+ return size_t(0); // Unreachable; keeps the non-void return path well-formed.
+ }
+
+ // Returns the allocated size of the buffer at 'ptr' if known,
+ // otherwise returns RequestedSize(ptr). AllocatedSize(ptr) is
+ // guaranteed to be >= RequestedSize(ptr).
+ //
+ // REQUIRES: TracksAllocationSizes() is true.
+ //
+ // REQUIRES: 'ptr!=nullptr' and points to a buffer previously
+ // allocated by this allocator.
+ virtual size_t AllocatedSize(void* ptr) { return RequestedSize(ptr); }
+
+ // TODO(jeff): Maybe provide some interface to give info about
+ // current allocation state (total number of bytes available for
+ // allocation, number of bytes free on device, etc.)
+};
+
+// A tensorflow Op may need access to different kinds of memory that
+// are not simply a function of the device to which the Op has been
+// assigned. For example, an Op executing on a GPU may still need
+// to allocate CPU RAM for some purpose. Internal to the tensorflow
+// runtime we may choose to allocate CPU ram from special regions
+// that have been prepared for higher performance in some use
+// contexts, e.g. doing DMA with particular devices. For these
+// reasons, the Device interface does not expose just one memory
+// Allocator, but instead provides an accessor that takes a
+// specification of the desired memory attributes in order to select
+// an Allocator.
+//
+// NOTE: The upper 8 bits of the value are reserved for
+// device-specific uses. Implementors of a device can interpret these
+// upper 8 bits in device-specific ways, and ops implemented for those
+// devices are responsible for setting those 8 bits appropriately.
+//
+// Example use:
+// // Allocator for ordinary device memory:
+// Allocator* a = allocator(AllocatorAttributes());
+// ...
+// // Allocator for CPU RAM, regardless of where Op is executing:
+// AllocatorAttributes attr;
+// attr.set_on_host(true);
+// Allocator* a = allocator(attr);
+struct AllocatorAttributes {
+ void set_on_host(bool v) { value |= (static_cast<int>(v)); }
+ bool on_host() const { return value & 0x1; }
+ void set_nic_compatible(bool v) { value |= (static_cast<int>(v) << 1); }
+ bool nic_compatible() const { return value & (0x1 << 1); }
+ void set_gpu_compatible(bool v) { value |= (static_cast<int>(v) << 2); }
+ bool gpu_compatible() const { return value & (0x1 << 2); }
+
+ void Merge(AllocatorAttributes other) { value |= other.value; }
+
+ uint32 value = 0;
+};
+
+// Returns a trivial implementation of Allocator which uses the system
+// default malloc.
+Allocator* cpu_allocator();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_ALLOCATOR_H_
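A minimal usage sketch for the typed convenience methods above, using only what this header and cpu_allocator() define (the overflow behavior is also exercised in the test that follows):

  Allocator* a = cpu_allocator();
  float* buf = a->Allocate<float>(1024);  // 32-byte aligned; constructors are not run
  if (buf != nullptr) {
    // ... fill the 1024 floats ...
    a->Deallocate(buf);
  }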
diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc
new file mode 100644
index 0000000000..6b1e52cfc4
--- /dev/null
+++ b/tensorflow/core/framework/allocator_test.cc
@@ -0,0 +1,61 @@
+#include "tensorflow/core/framework/allocator.h"
+#include <algorithm>
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+namespace tensorflow {
+
+TEST(CPUAllocatorTest, Simple) {
+ Allocator* a = cpu_allocator();
+ std::vector<void*> ptrs;
+ for (int s = 1; s < 1024; s++) {
+ void* raw = a->AllocateRaw(1, s);
+ ptrs.push_back(raw);
+ }
+ std::sort(ptrs.begin(), ptrs.end());
+ for (size_t i = 0; i < ptrs.size(); i++) {
+ if (i > 0) {
+ CHECK_NE(ptrs[i], ptrs[i - 1]); // No dups
+ }
+ a->DeallocateRaw(ptrs[i]);
+ }
+ float* t1 = a->Allocate<float>(1024);
+ double* t2 = a->Allocate<double>(1048576);
+ a->Deallocate(t1);
+ a->Deallocate(t2);
+}
+
+// Define a struct that we will use to observe behavior in the unit tests
+struct TestStruct {
+ int x; // not used; just here to make sure sizeof(TestStruct) > 1
+};
+
+TEST(CPUAllocatorTest, CheckStructSize) { CHECK_GT(sizeof(TestStruct), 1); }
+
+TEST(CPUAllocatorTest, AllocateOverflowMaxSizeT) {
+ Allocator* a = cpu_allocator();
+
+ // The maximum size_t value will definitely overflow.
+ size_t count_to_allocate = std::numeric_limits<size_t>::max();
+ TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
+
+ CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
+}
+
+TEST(CPUAllocatorTest, AllocateOverflowSmallest) {
+ Allocator* a = cpu_allocator();
+
+ // count_to_allocate is the smallest count that will cause overflow.
+ const size_t count_to_allocate =
+ (std::numeric_limits<size_t>::max() / sizeof(TestStruct)) + 1;
+ TestStruct* const test_pointer = a->Allocate<TestStruct>(count_to_allocate);
+
+ CHECK_EQ(test_pointer, reinterpret_cast<TestStruct*>(NULL));
+}
+
+TEST(CPUAllocatorTest, Sizes) {
+ Allocator* a = cpu_allocator();
+
+ EXPECT_EQ(false, a->TracksAllocationSizes());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/attr_value.proto b/tensorflow/core/framework/attr_value.proto
new file mode 100644
index 0000000000..c6a9940815
--- /dev/null
+++ b/tensorflow/core/framework/attr_value.proto
@@ -0,0 +1,57 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Protocol buffer representing the value for an attr used to configure an Op.
+// Comment indicates the corresponding attr type. Only the field matching the
+// attr type may be filled.
+message AttrValue {
+ message ListValue {
+ repeated bytes s = 2; // "list(string)"
+ repeated int64 i = 3 [packed = true]; // "list(int)"
+ repeated float f = 4 [packed = true]; // "list(float)"
+ repeated bool b = 5 [packed = true]; // "list(bool)"
+ repeated DataType type = 6 [packed = true]; // "list(type)"
+ repeated TensorShapeProto shape = 7; // "list(shape)"
+ repeated TensorProto tensor = 8; // "list(tensor)"
+ // TODO(zhifengc/josh11b): implements list(func) if needed.
+ }
+
+ oneof value {
+ bytes s = 2; // "string"
+ int64 i = 3; // "int"
+ float f = 4; // "float"
+ bool b = 5; // "bool"
+ DataType type = 6; // "type"
+ TensorShapeProto shape = 7; // "shape"
+ TensorProto tensor = 8; // "tensor"
+ ListValue list = 1; // any "list(...)"
+
+ // "func" represents a function. func.name is a function's name or
+ // a primitive op's name. func.attr.first is the name of an attr
+ // defined for that function. func.attr.second is the value for
+ // that attr in the instantiation.
+ NameAttrList func = 10;
+
+ // This is a placeholder only used in nodes defined inside a
+ // function. It indicates the attr value will be supplied when
+ // the function is instantiated. For example, let us suppose a
+ // node "N" in function "FN". "N" has an attr "A" with value
+ // placeholder = "foo". When FN is instantiated with attr "foo"
+ // set to "bar", the instantiated node N's attr A will have been
+ // given the value "bar".
+ string placeholder = 9;
+ }
+}
+
+// A list of attr names and their values. The whole list is attached
+// with a string name. E.g., MatMul[T=float].
+message NameAttrList {
+ string name = 1;
+ map<string, AttrValue> attr = 2;
+}
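As a sketch of the oneof usage described above, a few AttrValue instances built with the generated C++ setters (one field per value, a list for list-typed attrs, and a placeholder for nodes inside a function):

  AttrValue i_attr;
  i_attr.set_i(7);                          // attr type "int"
  AttrValue strings_attr;
  strings_attr.mutable_list()->add_s("a");  // attr type "list(string)"
  strings_attr.mutable_list()->add_s("b");
  AttrValue ph_attr;
  ph_attr.set_placeholder("T");             // filled in when the function is instantiated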
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
new file mode 100644
index 0000000000..400ef118b8
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -0,0 +1,382 @@
+#include "tensorflow/core/framework/attr_value_util.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+
+string SummarizeString(const string& str) {
+ return strings::StrCat("\"", str_util::CEscape(str), "\"");
+}
+
+string SummarizeShape(const TensorShapeProto& proto) {
+ TensorShape shape(proto);
+ return shape.ShortDebugString();
+}
+
+string SummarizeTensor(const TensorProto& tensor_proto) {
+ Tensor t;
+ if (!t.FromProto(tensor_proto)) {
+ return strings::StrCat("<Invalid TensorProto: ",
+ tensor_proto.ShortDebugString(), ">");
+ }
+ return t.DebugString();
+}
+
+} // namespace
+
+string SummarizeAttrValue(const AttrValue& attr_value) {
+ switch (attr_value.value_case()) {
+ case AttrValue::kS:
+ return SummarizeString(attr_value.s());
+ case AttrValue::kI:
+ return strings::StrCat(attr_value.i());
+ case AttrValue::kF:
+ return strings::StrCat(attr_value.f());
+ case AttrValue::kB:
+ return attr_value.b() ? "true" : "false";
+ case AttrValue::kType:
+ return DataType_Name(attr_value.type());
+ case AttrValue::kShape:
+ return SummarizeShape(attr_value.shape());
+ case AttrValue::kTensor:
+ return SummarizeTensor(attr_value.tensor());
+ case AttrValue::kList: {
+ string ret = "[";
+ if (attr_value.list().s_size() > 0) {
+ for (int i = 0; i < attr_value.list().s_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, SummarizeString(attr_value.list().s(i)));
+ }
+ } else if (attr_value.list().i_size() > 0) {
+ for (int i = 0; i < attr_value.list().i_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, attr_value.list().i(i));
+ }
+ } else if (attr_value.list().f_size() > 0) {
+ for (int i = 0; i < attr_value.list().f_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, attr_value.list().f(i));
+ }
+ } else if (attr_value.list().b_size() > 0) {
+ for (int i = 0; i < attr_value.list().b_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false");
+ }
+ } else if (attr_value.list().type_size() > 0) {
+ for (int i = 0; i < attr_value.list().type_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i)));
+ }
+ } else if (attr_value.list().shape_size() > 0) {
+ for (int i = 0; i < attr_value.list().shape_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, SummarizeShape(attr_value.list().shape(i)));
+ }
+ } else if (attr_value.list().tensor_size() > 0) {
+ for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret,
+ SummarizeTensor(attr_value.list().tensor(i)));
+ }
+ }
+ strings::StrAppend(&ret, "]");
+ return ret;
+ }
+ case AttrValue::kFunc: {
+ std::vector<string> entries;
+ for (auto p : attr_value.func().attr()) {
+ entries.push_back(
+ strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ return strings::StrCat(attr_value.func().name(), "[",
+ str_util::Join(entries, ", "), "]");
+ }
+ case AttrValue::kPlaceholder:
+ return strings::StrCat("$", attr_value.placeholder());
+ case AttrValue::VALUE_NOT_SET:
+ return "<Unknown AttrValue type>";
+ }
+ return "<Unknown AttrValue type>"; // Prevent missing return warning
+}
+
+Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
+ int num_set = 0;
+
+#define VALIDATE_FIELD(name, type_string, oneof_case) \
+ do { \
+ if (attr_value.has_list()) { \
+ if (attr_value.list().name##_size() > 0) { \
+ if (type != "list(" type_string ")") { \
+ return errors::InvalidArgument( \
+ "AttrValue had value with type list(" type_string ") when ", \
+ type, " expected"); \
+ } \
+ ++num_set; \
+ } \
+ } else if (attr_value.value_case() == AttrValue::oneof_case) { \
+ if (type != type_string) { \
+ return errors::InvalidArgument( \
+ "AttrValue had value with type " type_string " when ", type, \
+ " expected"); \
+ } \
+ ++num_set; \
+ } \
+ } while (false)
+
+ VALIDATE_FIELD(s, "string", kS);
+ VALIDATE_FIELD(i, "int", kI);
+ VALIDATE_FIELD(f, "float", kF);
+ VALIDATE_FIELD(b, "bool", kB);
+ VALIDATE_FIELD(type, "type", kType);
+ VALIDATE_FIELD(shape, "shape", kShape);
+ VALIDATE_FIELD(tensor, "tensor", kTensor);
+
+#undef VALIDATE_FIELD
+
+ if (attr_value.value_case() == AttrValue::kFunc) {
+ if (type != "func") {
+ return errors::InvalidArgument(
+ "AttrValue had value with type 'func' when ", type, " expected");
+ }
+ ++num_set;
+ }
+
+ if (attr_value.value_case() == AttrValue::kPlaceholder) {
+ return errors::InvalidArgument(
+ "AttrValue had value with unexpected type 'placeholder");
+ }
+
+ // If the attr type is 'list', we expect attr_value.has_list() to be true.
+ // However, proto3's attr_value.has_list() can be false when set to an empty
+ // list. So we simply check if has_list is false and some other field in
+ // attr_value is set to flag the error.
+ if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
+ if (num_set) {
+ return errors::InvalidArgument(
+ "AttrValue missing value with expected type ", type);
+ } else {
+ // Indicate that we have a list, but an empty one.
+ ++num_set;
+ }
+ }
+
+ // Okay to have an empty list, but not to be missing a non-list value.
+ if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
+ return errors::InvalidArgument(
+ "AttrValue missing value with expected type ", type);
+ }
+
+ // Ref types and DT_INVALID are illegal.
+ if (type == "type") {
+ if (IsRefType(attr_value.type())) {
+ return errors::InvalidArgument(
+ "AttrValue must not have reference type value of ",
+ DataTypeString(attr_value.type()));
+ }
+ if (attr_value.type() == DT_INVALID) {
+ return errors::InvalidArgument("AttrValue has invalid DataType");
+ }
+ } else if (type == "list(type)") {
+ for (auto as_int : attr_value.list().type()) {
+ const DataType dtype = static_cast<DataType>(as_int);
+ if (IsRefType(dtype)) {
+ return errors::InvalidArgument(
+ "AttrValue must not have reference type value of ",
+ DataTypeString(dtype));
+ }
+ if (dtype == DT_INVALID) {
+ return errors::InvalidArgument("AttrValue contains invalid DataType");
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
+ // Parse type.
+ string field_name;
+ bool is_list = type.Consume("list(");
+ if (type.Consume("string")) {
+ field_name = "s";
+ } else if (type.Consume("int")) {
+ field_name = "i";
+ } else if (type.Consume("float")) {
+ field_name = "f";
+ } else if (type.Consume("bool")) {
+ field_name = "b";
+ } else if (type.Consume("type")) {
+ field_name = "type";
+ } else if (type.Consume("shape")) {
+ field_name = "shape";
+ } else if (type.Consume("tensor")) {
+ field_name = "tensor";
+ } else if (type.Consume("func")) {
+ field_name = "func";
+ } else if (type.Consume("placeholder")) {
+ field_name = "placeholder";
+ } else {
+ return false;
+ }
+ if (is_list && !type.Consume(")")) {
+ return false;
+ }
+
+ // Construct a valid text proto message to parse.
+ string to_parse;
+ if (is_list) {
+ // TextFormat parser considers "i: 7" to be the same as "i: [7]",
+ // but we only want to allow list values with [].
+ if (!RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[.*\\]\\s*")) {
+ return false;
+ }
+ if (RE2::FullMatch(ToRegexpStringPiece(text), "\\s*\\[\\s*\\]\\s*")) {
+ // User wrote "[]", so return empty list without invoking the TextFormat
+ // parse which returns an error for "i: []".
+ out->Clear();
+ out->mutable_list();
+ return true;
+ }
+ to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
+ } else {
+ to_parse = strings::StrCat(field_name, ": ", text);
+ }
+
+ // Parse if we can.
+ return protobuf::TextFormat::ParseFromString(to_parse, out);
+}
+
+#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
+ void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
+
+#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
+ void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
+ out->mutable_list(); /* create list() even if value empty */ \
+ for (const auto& v : value) { \
+ out->mutable_list()->add_##FIELD(v); \
+ } \
+ }
+
+#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
+ DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
+ DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
+
+DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
+DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
+DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
+DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
+DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
+DEFINE_SET_ATTR_VALUE_BOTH(float, f)
+DEFINE_SET_ATTR_VALUE_BOTH(double, f)
+DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
+DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
+DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
+DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
+
+void SetAttrValue(StringPiece value, AttrValue* out) {
+ out->set_s(value.data(), value.size());
+}
+
+void SetAttrValue(const TensorShape& value, AttrValue* out) {
+ value.AsProto(out->mutable_shape());
+}
+
+void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
+ out->mutable_list(); // Create list() even if value empty.
+ for (const auto& v : value) {
+ v.AsProto(out->mutable_list()->add_shape());
+ }
+}
+
+void SetAttrValue(const Tensor& value, AttrValue* out) {
+ if (value.NumElements() > 1) {
+ value.AsProtoTensorContent(out->mutable_tensor());
+ } else {
+ value.AsProtoField(out->mutable_tensor());
+ }
+}
+
+void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
+ out->mutable_list(); // Create list() even if value empty.
+ for (const auto& v : value) {
+ if (v.NumElements() > 1) {
+ v.AsProtoTensorContent(out->mutable_list()->add_tensor());
+ } else {
+ v.AsProtoField(out->mutable_list()->add_tensor());
+ }
+ }
+}
+
+void SetAttrValue(const TensorProto& value, AttrValue* out) {
+ *out->mutable_tensor() = value;
+}
+
+void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
+ out->mutable_list(); // Create list() even if value empty.
+ for (const auto& v : value) {
+ *out->mutable_list()->add_tensor() = v;
+ }
+}
+
+void SetAttrValue(const NameAttrList& value, AttrValue* out) {
+ *out->mutable_func() = value;
+}
+
+bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
+ string a_str, b_str;
+ a.SerializeToString(&a_str);
+ b.SerializeToString(&b_str);
+ // Note: it should be safe to compare proto serializations of the attr
+ // values since at most one field should be set in each (indeed, it
+ // must be the same field if they are to compare equal).
+ // Exception: there are multiple equivalent representations of
+ // TensorProtos. So a return value of true implies a == b, but not the
+ // converse.
+ return a_str == b_str;
+}
+
+bool HasPlaceHolder(const AttrValue& val) {
+ switch (val.value_case()) {
+ case AttrValue::kFunc:
+ for (const auto& p : val.func().attr()) {
+ if (HasPlaceHolder(p.second)) {
+ return true;
+ }
+ }
+ break;
+ case AttrValue::kPlaceholder:
+ return true;
+ default:
+ break;
+ }
+ return false;
+}
+
+bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value) {
+ switch (value->value_case()) {
+ case AttrValue::kFunc:
+ for (auto& p : *(value->mutable_func()->mutable_attr())) {
+ if (!SubstitutePlaceholders(substitute, &p.second)) {
+ return false;
+ }
+ }
+ break;
+ case AttrValue::kPlaceholder:
+ return substitute(value->placeholder(), value);
+ case AttrValue::VALUE_NOT_SET:
+ return false;
+ default:
+ break;
+ }
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
new file mode 100644
index 0000000000..1faf74a327
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -0,0 +1,83 @@
+#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// A human-readable rendering of attr_value, that is more concise than a
+// text-format proto.
+string SummarizeAttrValue(const AttrValue& attr_value);
+
+// Generates an error if attr_value doesn't have the indicated attr type.
+Status AttrValueHasType(const AttrValue& attr_value, StringPiece type);
+
+// Converts a text proto value from "text" into the field of *out
+// indicated by "type" (e.g. from the type field of an AttrDef).
+// Examples:
+// * If type:"int" and text:"-14", then *out is set to "i: -14"
+// * If type:"list(string)" and text:"['foo', 'bar']",
+// then *out is set to "list { s: ['foo', 'bar'] }"
+// Returns true on success.
+bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out);
+
+// Sets *out based on the type of value.
+void SetAttrValue(const string& value, AttrValue* out);
+void SetAttrValue(const char* value, AttrValue* out);
+void SetAttrValue(StringPiece value, AttrValue* out);
+void SetAttrValue(int64 value, AttrValue* out);
+void SetAttrValue(int32 value, AttrValue* out);
+void SetAttrValue(float value, AttrValue* out);
+void SetAttrValue(double value, AttrValue* out);
+void SetAttrValue(bool value, AttrValue* out);
+void SetAttrValue(DataType value, AttrValue* out);
+void SetAttrValue(const TensorShape& value, AttrValue* out);
+void SetAttrValue(const Tensor& value, AttrValue* out);
+void SetAttrValue(const TensorProto& value, AttrValue* out);
+void SetAttrValue(const NameAttrList& value, AttrValue* out);
+
+void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<int64> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out);
+void SetAttrValue(const std::vector<bool>& value, AttrValue* out);
+void SetAttrValue(std::initializer_list<bool> value, AttrValue* out);
+void SetAttrValue(DataTypeSlice value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out);
+void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out);
+
+inline void SetAttrValue(const AttrValue& value, AttrValue* out) {
+ *out = value;
+}
+
+// Returns true if a and b have the same value.
+// NOTE: May return false negatives for tensor values.
+bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b);
+
+// Returns true if "val" has a placeholder.
+bool HasPlaceHolder(const AttrValue& val);
+
+// SubstitutePlaceholders recursively replaces placeholders in 'value'
+// with an attr value by calling SubstituteFunc. Returns true iff all
+// placeholders in "value" are replaced with a value.
+//
+// SubstituteFunc is given a placeholder string. If the placeholder is
+// unknown, SubstituteFunc returns false. Otherwise, it overwrites the
+// attr value and returns true.
+typedef std::function<bool(const string&, AttrValue*)> SubstituteFunc;
+bool SubstitutePlaceholders(SubstituteFunc substitute, AttrValue* value);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
new file mode 100644
index 0000000000..bdfbf1707a
--- /dev/null
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -0,0 +1,91 @@
+#include "tensorflow/core/framework/attr_value_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// A few helpers to construct AttrValue protos.
+template <typename T>
+AttrValue V(T value) {
+ AttrValue ret;
+ SetAttrValue(value, &ret);
+ return ret;
+}
+
+AttrValue P(const string& p) {
+ AttrValue ret;
+ ret.set_placeholder(p);
+ return ret;
+}
+
+AttrValue F(const string& name,
+ std::vector<std::pair<string, AttrValue> > pairs) {
+ AttrValue ret;
+ ret.mutable_func()->set_name(name);
+ ret.mutable_func()->mutable_attr()->insert(pairs.begin(), pairs.end());
+ return ret;
+}
+
+TEST(AttrValueUtil, HasType) {
+ // OK
+ EXPECT_TRUE(AttrValueHasType(V(123), "int").ok());
+ EXPECT_TRUE(AttrValueHasType(V(1.2), "float").ok());
+ EXPECT_TRUE(AttrValueHasType(V(DT_FLOAT), "type").ok());
+ EXPECT_TRUE(AttrValueHasType(F("f", {}), "func").ok());
+
+ // not OK.
+ EXPECT_FALSE(AttrValueHasType(V(123), "func").ok());
+ EXPECT_FALSE(AttrValueHasType(V(1.2), "int").ok());
+ EXPECT_FALSE(AttrValueHasType(V(DT_FLOAT), "shape").ok());
+ EXPECT_FALSE(AttrValueHasType(F("f", {}), "string").ok());
+ EXPECT_FALSE(AttrValueHasType(P("T"), "float").ok());
+}
+
+SubstituteFunc ReplaceTWith(const AttrValue& val) {
+ return [val](const string& placeholder, AttrValue* target) {
+ if (placeholder == "T") {
+ *target = val;
+ return true;
+ } else {
+ return false;
+ }
+ };
+}
+
+TEST(AttrValueUtil, Basic) {
+ auto v = F("MatMul", {{"dtype", P("T")},
+ {"transpose_a", V(false)},
+ {"transpose_b", V(true)},
+ {"use_cublas", V(true)}});
+ TF_CHECK_OK(AttrValueHasType(v, "func"));
+ EXPECT_TRUE(HasPlaceHolder(v));
+
+ EXPECT_EQ(
+ SummarizeAttrValue(v),
+ "MatMul[dtype=$T, transpose_a=false, transpose_b=true, use_cublas=true]");
+
+ SubstitutePlaceholders(ReplaceTWith(V(DT_FLOAT)), &v);
+ EXPECT_TRUE(!HasPlaceHolder(v));
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "MatMul[dtype=DT_FLOAT, transpose_a=false, transpose_b=true, "
+ "use_cublas=true]");
+}
+
+TEST(AttrValueUtil, DeepAttr) {
+ auto v = F("f", {{"T", P("T")}});
+ TF_CHECK_OK(AttrValueHasType(v, "func"));
+ EXPECT_TRUE(HasPlaceHolder(v));
+
+ for (int i = 0; i < 3; ++i) {
+ v = F("f", {{"T", P("T")}, {"F", v}});
+ EXPECT_TRUE(HasPlaceHolder(v));
+ }
+ EXPECT_EQ(SummarizeAttrValue(v), "f[F=f[F=f[F=f[T=$T], T=$T], T=$T], T=$T]");
+
+ SubstitutePlaceholders(ReplaceTWith(F("x", {})), &v);
+ EXPECT_TRUE(!HasPlaceHolder(v));
+ EXPECT_EQ(SummarizeAttrValue(v),
+ "f[F=f[F=f[F=f[T=x[]], T=x[]], T=x[]], T=x[]]");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc
new file mode 100644
index 0000000000..0068283367
--- /dev/null
+++ b/tensorflow/core/framework/bfloat16.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/framework/bfloat16.h"
+
+namespace tensorflow {
+
+void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+ uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+ // Keep only the upper 16 bits of each float (sign, exponent, and the top
+ // 7 mantissa bits). Assumes little-endian layout, where p[1] holds the
+ // high-order half of the 32-bit float.
+ for (; size; p += 2, q++, size--) {
+ *q = p[1];
+ }
+}
+
+void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
+ uint16_t* q = reinterpret_cast<uint16_t*>(dst);
+ // Reconstruct a float by placing the bfloat16 bits in the high-order half
+ // and zeroing the low-order mantissa bits (little-endian layout assumed).
+ for (; size; p++, q += 2, size--) {
+ q[0] = 0;
+ q[1] = *p;
+ }
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h
new file mode 100644
index 0000000000..9cd260ee13
--- /dev/null
+++ b/tensorflow/core/framework/bfloat16.h
@@ -0,0 +1,58 @@
+#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+// Compact 16-bit encoding of floating point numbers. This representation uses
+// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It
+// is assumed that floats are in IEEE 754 format so the representation is just
+// bits 16-31 of a single precision float.
+//
+// NOTE: The IEEE floating point standard defines a float16 format that
+// is different than this format (it has fewer bits of exponent and more
+// bits of mantissa). We don't use that format here because conversion
+// to/from 32-bit floats is more complex for that format, and the
+// conversion for this format is very simple.
+//
+// Because of the existing IEEE float16 type, we do not name our representation
+// "float16"; the struct below is called "bfloat16" and simply wraps the raw
+// uint16 bits.
+//
+// <-----our 16bits float------->
+// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f
+// <------------------------------float-------------------------->
+// 3 3 2 2 1 1 0
+// 1 0 3 2 5 4 0
+//
+//
+// This type only supports conversion back and forth with float.
+//
+// This file must be compilable by nvcc.
+
+namespace tensorflow {
+struct bfloat16 {
+ EIGEN_DEVICE_FUNC bfloat16() {}
+ EIGEN_DEVICE_FUNC explicit bfloat16(const uint16_t v) : value(v) {}
+
+ uint16_t value;
+};
+
+// Conversion routines between an array of float and bfloat16 of
+// "size".
+void FloatToBFloat16(const float* src, bfloat16* dst, int64 size);
+void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size);
+
+} // namespace tensorflow
+
+namespace Eigen {
+template <>
+struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
+
+EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
+ const tensorflow::bfloat16 b) {
+ return a.value == b.value;
+}
+
+} // namespace Eigen
+
+#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
new file mode 100644
index 0000000000..4fe791fdeb
--- /dev/null
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -0,0 +1,69 @@
+#include "tensorflow/core/framework/bfloat16.h"
+
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(Bfloat16Test, Simple) {
+ bfloat16 a(12);
+ EXPECT_EQ(12, a.value);
+}
+
+TEST(Bfloat16Test, Conversion) {
+ float a[100];
+ for (int i = 0; i < 100; ++i) {
+ a[i] = i + 1.25;
+ }
+ bfloat16 b[100];
+ float c[100];
+ FloatToBFloat16(a, b, 100);
+ BFloat16ToFloat(b, c, 100);
+ for (int i = 0; i < 100; ++i) {
+ // The relative error should be less than 1/(2^7) since bfloat16
+ // has 7 bits mantissa.
+ EXPECT_LE(fabs(c[i] - a[i]) / a[i], 1.0 / 128);
+ }
+}
+
+static void BM_FloatToBFloat16(int iters) {
+ testing::StopTiming();
+ static const int N = 32 << 20;
+ const int64 tot = static_cast<int64>(iters) * N;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
+
+ float* inp = new float[N];
+ bfloat16* out = new bfloat16[N];
+
+ testing::StartTiming();
+ while (iters--) {
+ FloatToBFloat16(inp, out, N);
+ }
+ delete[] inp;
+ delete[] out;
+}
+BENCHMARK(BM_FloatToBFloat16);
+
+static void BM_BFloat16ToFloat(int iters) {
+ testing::StopTiming();
+ static const int N = 32 << 20;
+ const int64 tot = static_cast<int64>(iters) * N;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * (sizeof(float) + sizeof(bfloat16)));
+
+ bfloat16* inp = new bfloat16[N];
+ float* out = new float[N];
+
+ testing::StartTiming();
+ while (iters--) {
+ BFloat16ToFloat(inp, out, N);
+ }
+ delete[] inp;
+ delete[] out;
+}
+BENCHMARK(BM_BFloat16ToFloat);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
new file mode 100644
index 0000000000..51423792a8
--- /dev/null
+++ b/tensorflow/core/framework/cancellation.cc
@@ -0,0 +1,79 @@
+#include "tensorflow/core/framework/cancellation.h"
+
+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+const CancellationToken CancellationManager::kInvalidToken = -1;
+
+CancellationManager::CancellationManager()
+ : is_cancelling_(false), is_cancelled_(0), next_cancellation_token_(0) {}
+
+void CancellationManager::StartCancel() {
+ std::unordered_map<CancellationToken, CancelCallback> callbacks_to_run;
+ {
+ mutex_lock l(mu_);
+ if (is_cancelled_.load(std::memory_order_relaxed) || is_cancelling_) {
+ return;
+ }
+ is_cancelling_ = true;
+ std::swap(callbacks_, callbacks_to_run);
+ }
+ // We call these callbacks without holding mu_, so that concurrent
+ // calls to DeregisterCallback, which can happen asynchronously, do
+ // not block. The callbacks remain valid because any concurrent call
+ // to DeregisterCallback will block until the
+ // cancelled_notification_ is notified.
+ for (auto key_and_value : callbacks_to_run) {
+ key_and_value.second();
+ }
+ {
+ mutex_lock l(mu_);
+ is_cancelling_ = false;
+ is_cancelled_.store(true, std::memory_order_release);
+ }
+ cancelled_notification_.Notify();
+}
+
+CancellationToken CancellationManager::get_cancellation_token() {
+ mutex_lock l(mu_);
+ return next_cancellation_token_++;
+}
+
+bool CancellationManager::RegisterCallback(CancellationToken token,
+ CancelCallback callback) {
+ mutex_lock l(mu_);
+ CHECK_LT(token, next_cancellation_token_) << "Invalid cancellation token";
+ bool should_register = !is_cancelled_ && !is_cancelling_;
+ if (should_register) {
+ std::swap(callbacks_[token], callback);
+ }
+ return should_register;
+}
+
+bool CancellationManager::DeregisterCallback(CancellationToken token) {
+ mu_.lock();
+ if (is_cancelled_) {
+ mu_.unlock();
+ return false;
+ } else if (is_cancelling_) {
+ mu_.unlock();
+ // Wait for all of the cancellation callbacks to be called. This
+ // wait ensures that the caller of DeregisterCallback does not
+ // return immediately and free objects that may be used in the
+ // execution of any currently pending callbacks in StartCancel.
+ cancelled_notification_.WaitForNotification();
+ return false;
+ } else {
+ callbacks_.erase(token);
+ mu_.unlock();
+ return true;
+ }
+}
+
+CancellationManager::~CancellationManager() { StartCancel(); }
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
new file mode 100644
index 0000000000..feda548e97
--- /dev/null
+++ b/tensorflow/core/framework/cancellation.h
@@ -0,0 +1,121 @@
+#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+
+#include <atomic>
+#include <functional>
+#include <unordered_map>
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A token that can be used to register and deregister a
+// CancelCallback with a CancellationManager.
+//
+// CancellationToken values must be created by a call to
+// CancellationManager::get_cancellation_token.
+typedef int64 CancellationToken;
+
+// A callback that is invoked when a step is cancelled.
+//
+// NOTE(mrry): See caveats about CancelCallback implementations in the
+// comment for CancellationManager::RegisterCallback.
+typedef std::function<void()> CancelCallback;
+
+class CancellationManager {
+ public:
+ // A value that won't be returned by get_cancellation_token().
+ static const CancellationToken kInvalidToken;
+
+ CancellationManager();
+ ~CancellationManager();
+
+ // Run all callbacks associated with this manager.
+ void StartCancel();
+
+ // Returns true iff StartCancel() has been called.
+ bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
+
+ // Returns a token that must be used in calls to RegisterCallback
+ // and DeregisterCallback.
+ CancellationToken get_cancellation_token();
+
+ // Attempts to register the given callback to be invoked when this
+ // manager is cancelled. Returns true if the callback was
+ // registered; returns false if this manager was already cancelled,
+ // and the callback was not registered.
+ //
+ // If this method returns false, it is the caller's responsibility
+ // to perform any cancellation cleanup.
+ //
+ // This method is tricky to use correctly. The following usage pattern
+ // is recommended:
+ //
+ // class ObjectWithCancellableOperation {
+ // mutex mu_;
+ // void CancellableOperation(CancellationManager* cm,
+ // std::function<void(Status)> callback) {
+ // bool already_cancelled;
+ // CancellationToken token = cm->get_cancellation_token();
+ // {
+ // mutex_lock(mu_);
+ //       already_cancelled = !cm->RegisterCallback(
+ //           token, [this, token]() { Cancel(token); });
+ // if (!already_cancelled) {
+ // // Issue asynchronous operation. Associate the pending operation
+ // // with `token` in some object state, or provide another way for
+ // // the Cancel method to look up the operation for cancellation.
+ // // Ensure that `cm->DeregisterCallback(token)` is called without
+ // // holding `mu_`, before `callback` is invoked.
+ // // ...
+ // }
+ // }
+ // if (already_cancelled) {
+ // callback(errors::Cancelled("Operation was cancelled"));
+ // }
+ // }
+ //
+ // void Cancel(CancellationToken token) {
+ // mutex_lock(mu_);
+ // // Take action to cancel the operation with the given cancellation
+ // // token.
+ // }
+ //
+ // NOTE(mrry): The caller should take care that (i) the calling code
+ // is robust to `callback` being invoked asynchronously (e.g. from
+ // another thread), (ii) `callback` is deregistered by a call to
+ // this->DeregisterCallback(token) when the operation completes
+ // successfully, and (iii) `callback` does not invoke any method
+ // on this cancellation manager. Furthermore, it is important that
+ // the eventual caller of the complementary DeregisterCallback does not
+ // hold any mutexes that are required by `callback`.
+ bool RegisterCallback(CancellationToken token, CancelCallback callback);
+
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // after the callback has been invoked, blocking if necessary.
+ //
+ // NOTE(mrry): This method may block if cancellation is in progress.
+ // The caller of this method must not hold any mutexes that are required
+ // to invoke any cancellation callback that has been registered with this
+ // cancellation manager.
+ bool DeregisterCallback(CancellationToken token);
+
+ private:
+ bool is_cancelling_;
+ std::atomic_bool is_cancelled_;
+
+ mutex mu_;
+ Notification cancelled_notification_;
+ CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
+ std::unordered_map<CancellationToken, CancelCallback> callbacks_
+ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
new file mode 100644
index 0000000000..1925dd20cc
--- /dev/null
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -0,0 +1,102 @@
+#include "tensorflow/core/framework/cancellation.h"
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Cancellation, SimpleNoCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->DeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, SimpleCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ delete manager;
+}
+
+TEST(Cancellation, CancelBeforeRegister) {
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ manager->StartCancel();
+ bool registered = manager->RegisterCallback(token, nullptr);
+ EXPECT_FALSE(registered);
+ delete manager;
+}
+
+TEST(Cancellation, DeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->DeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, CancelMultiple) {
+ bool is_cancelled_1 = false, is_cancelled_2 = false, is_cancelled_3 = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token_1 = manager->get_cancellation_token();
+ bool registered_1 = manager->RegisterCallback(
+ token_1, [&is_cancelled_1]() { is_cancelled_1 = true; });
+ EXPECT_TRUE(registered_1);
+ auto token_2 = manager->get_cancellation_token();
+ bool registered_2 = manager->RegisterCallback(
+ token_2, [&is_cancelled_2]() { is_cancelled_2 = true; });
+ EXPECT_TRUE(registered_2);
+ EXPECT_FALSE(is_cancelled_1);
+ EXPECT_FALSE(is_cancelled_2);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled_1);
+ EXPECT_TRUE(is_cancelled_2);
+ EXPECT_FALSE(is_cancelled_3);
+ auto token_3 = manager->get_cancellation_token();
+ bool registered_3 = manager->RegisterCallback(
+ token_3, [&is_cancelled_3]() { is_cancelled_3 = true; });
+ EXPECT_FALSE(registered_3);
+ EXPECT_FALSE(is_cancelled_3);
+ delete manager;
+}
+
+TEST(Cancellation, IsCancelled) {
+ CancellationManager* cm = new CancellationManager();
+ thread::ThreadPool w(Env::Default(), "test", 4);
+ std::vector<Notification> done(8);
+ for (size_t i = 0; i < done.size(); ++i) {
+ Notification* n = &done[i];
+ w.Schedule([n, cm]() {
+ while (!cm->IsCancelled()) {
+ }
+ n->Notify();
+ });
+ }
+ Env::Default()->SleepForMicroseconds(1000000 /* 1 second */);
+ cm->StartCancel();
+ for (size_t i = 0; i < done.size(); ++i) {
+ done[i].WaitForNotification();
+ }
+ delete cm;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
new file mode 100644
index 0000000000..f0def3d6d7
--- /dev/null
+++ b/tensorflow/core/framework/config.proto
@@ -0,0 +1,61 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+message GPUOptions {
+ // A value between 0 and 1 that indicates what fraction of the
+ // available GPU memory to pre-allocate for each process. 1 means
+ // to pre-allocate all of the GPU memory, 0.5 means the process
+ // allocates ~50% of the available GPU memory.
+ double per_process_gpu_memory_fraction = 1;
+};
+
+// Session configuration parameters.
+// The system picks appropriate values for fields that are not set.
+message ConfigProto {
+ // Map from device type name (e.g., "CPU" or "GPU") to maximum
+ // number of devices of that type to use. If a particular device
+ // type is not found in the map, the system picks an appropriate
+ // number.
+ map<string, int32> device_count = 1;
+
+ // The execution of an individual op (for some op types) can be
+ // parallelized on a pool of intra_op_parallelism_threads.
+ // 0 means the system picks an appropriate number.
+ int32 intra_op_parallelism_threads = 2;
+
+ // Nodes that perform blocking operations are enqueued on a pool of
+ // inter_op_parallelism_threads available in each process.
+ //
+ // 0 means the system picks an appropriate number.
+ //
+ // Note that the first Session created in the process sets the
+ // number of threads for all future sessions.
+ int32 inter_op_parallelism_threads = 5;
+
+ // Assignment of Nodes to Devices is recomputed every placement_period
+ // steps until the system warms up (at which point the recomputation
+ // typically slows down automatically).
+ int32 placement_period = 3;
+
+ // When any filters are present sessions will ignore all devices which do not
+ // match the filters. Each filter can be partially specified, e.g. "/job:ps",
+ // "/job:worker/replica:3", etc.
+ repeated string device_filters = 4;
+
+ // Options that apply to all GPUs.
+ GPUOptions gpu_options = 6;
+
+ // Whether soft placement is allowed. If allow_soft_placement is true,
+ // an op will be placed on CPU if
+ // 1. there's no GPU implementation for the op,
+ // or
+ // 2. no GPU devices are known or registered,
+ // or
+ // 3. the op needs to be co-located with reference-type input(s), which are on CPU.
+ bool allow_soft_placement = 7;
+
+ // Whether device placements should be logged.
+ bool log_device_placement = 8;
+};
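A hedged sketch of a ConfigProto populated through the generated C++ API, touching the fields defined above (the specific values are arbitrary and only illustrate the knobs):

  ConfigProto config;
  (*config.mutable_device_count())["CPU"] = 2;
  config.set_intra_op_parallelism_threads(4);
  config.mutable_gpu_options()->set_per_process_gpu_memory_fraction(0.5);
  config.set_allow_soft_placement(true);
  config.set_log_device_placement(true);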
diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h
new file mode 100644
index 0000000000..f59e0f5310
--- /dev/null
+++ b/tensorflow/core/framework/control_flow.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+const uint64 kIllegalFrameId = ~0uLL;
+const int64 kIllegalIterId = -1;
+
+// For the purpose of control flow, every tensor produced by TensorFlow is
+// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a
+// 'frame_id' and an 'iter_id'. The tensor value it represents is produced
+// in the frame with frame_id at the iteration of iter_id.
+struct FrameAndIter {
+ uint64 frame_id = kIllegalFrameId;
+ int64 iter_id = kIllegalIterId;
+
+ FrameAndIter() {}
+
+ FrameAndIter(uint64 frame, int64 iter) {
+ frame_id = frame;
+ iter_id = iter;
+ }
+
+ bool operator==(const FrameAndIter& other) const {
+ return (frame_id == other.frame_id && iter_id == other.iter_id);
+ }
+};
+
+struct FrameAndIterHash {
+ size_t operator()(const FrameAndIter& key) const {
+ // Make sure there are no padding bytes that we don't want
+ CHECK_EQ(sizeof(uint64) + sizeof(int64), sizeof(FrameAndIter));
+ return Hash64(reinterpret_cast<const char*>(&key), sizeof(FrameAndIter));
+ }
+};
+
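+// Example (hypothetical usage sketch): FrameAndIterHash lets FrameAndIter be
+// used as a hash map key, e.g. for per-frame/iteration bookkeeping:
+//
+//   std::unordered_map<FrameAndIter, int, FrameAndIterHash> pending;
+//   pending[FrameAndIter(0, 0)] = 2;
+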
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
diff --git a/tensorflow/core/framework/device_attributes.proto b/tensorflow/core/framework/device_attributes.proto
new file mode 100644
index 0000000000..7592215d1e
--- /dev/null
+++ b/tensorflow/core/framework/device_attributes.proto
@@ -0,0 +1,35 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// BusAdjacency identifies the ability of a device to participate in
+// maximally efficient DMA operations within the local context of a
+// process.
+//
+// This is currently ignored.
+enum BusAdjacency {
+ BUS_0 = 0;
+ BUS_1 = 1;
+ BUS_ANY = 2;
+ BUS_NUM_ADJACENCIES = 3;
+};
+
+message DeviceAttributes {
+ string name = 1;
+
+ // String representation of device_type.
+ string device_type = 2;
+
+ // Memory capacity of device in bytes.
+ int64 memory_limit = 4;
+
+ BusAdjacency bus_adjacency = 5;
+
+  // A device is assigned a globally unique number each time it is
+  // initialized. "incarnation" should never be 0.
+ fixed64 incarnation = 6;
+
+ // String representation of the physical device that this device maps to.
+ string physical_device_desc = 7;
+}
diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc
new file mode 100644
index 0000000000..83ad199062
--- /dev/null
+++ b/tensorflow/core/framework/device_base.cc
@@ -0,0 +1,7 @@
+#include "tensorflow/core/framework/device_base.h"
+
+namespace tensorflow {
+
+DeviceBase::~DeviceBase() {}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
new file mode 100644
index 0000000000..ed4ffc5d94
--- /dev/null
+++ b/tensorflow/core/framework/device_base.h
@@ -0,0 +1,172 @@
+#ifndef TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+#define TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/public/status.h"
+
+namespace Eigen {
+class ThreadPoolDevice;
+} // end namespace Eigen
+
+namespace perftools {
+namespace gputools {
+class Stream;
+} // namespace gputools
+} // namespace perftools
+
+namespace tensorflow {
+
+class Device;
+class Env;
+class EventMgr;
+
+namespace thread {
+class ThreadPool;
+}
+
+// A wrapper for an Eigen Gpu Device that includes per-op state
+class PerOpGpuDevice {
+ public:
+ virtual ~PerOpGpuDevice() {}
+ virtual const Eigen::GpuDevice& device() const = 0;
+};
+
+// A class that devices can subclass to pass around
+// Device-specific context to OpKernels.
+class DeviceContext : public core::RefCounted {
+ public:
+ ~DeviceContext() override {}
+ virtual perftools::gputools::Stream* stream() const { return nullptr; }
+ virtual void MaintainLifetimeOnStream(
+ const Tensor* t, perftools::gputools::Stream* stream) const {}
+
+ // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
+ // "device_tensor" which is on a GPU device "device". "device_tensor"
+ // must be allocated to be of the same size as "cpu_tensor".
+ virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ done(errors::Internal("Unrecognized device type in CPU-to-device Copy"));
+ }
+
+ // "device_tensor" is a tensor on a non-CPU device. Copies
+ // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated
+ // to be of the same size as "device_tensor".
+ virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ const string& tensor_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) {
+ done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
+ }
+};
+
+typedef std::unordered_map<int, DeviceContext*> DeviceContextMap;
+
+class DeviceBase {
+ public:
+ explicit DeviceBase(Env* env) : env_(env) {}
+ virtual ~DeviceBase();
+
+ Env* env() const { return env_; }
+
+ // Override this to return true for devices that require an Op's
+ // compute method to save references to the temporary tensors it
+  // allocates until the Op execution completes.
+ virtual bool SaveTemporaryTensors() const { return false; }
+
+ struct CpuWorkerThreads {
+ int num_threads = 0;
+ thread::ThreadPool* workers = nullptr;
+ };
+
+ // Does not take ownership.
+ void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) {
+ cpu_worker_threads_ = t;
+ }
+
+ const CpuWorkerThreads* tensorflow_cpu_worker_threads() const {
+ CHECK(cpu_worker_threads_ != nullptr);
+ return cpu_worker_threads_;
+ }
+
+ // "stream" is used in special circumstances (such as the
+ // constructors of Ops) where there is no available OpKernelContext.
+ // "default_context" is used by OpKernelContext whenever a device does not
+ // supply a DeviceContext for an op in FillContextMap (e.g. when only
+ // using a single stream.)
+ // "event_mgr" is used to delay deallocation of temporary GPU buffers.
+ // TODO(pbar) Work out how to move this out of DeviceBase.
+ struct GpuDeviceInfo {
+ perftools::gputools::Stream* stream;
+ DeviceContext* default_context;
+ EventMgr* event_mgr;
+ };
+
+ // Does not take ownership.
+ void set_tensorflow_gpu_device_info(GpuDeviceInfo* g) {
+ gpu_device_info_ = g;
+ }
+
+ const GpuDeviceInfo* tensorflow_gpu_device_info() const {
+ return gpu_device_info_;
+ }
+
+ // Does not take ownership.
+ void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
+ eigen_cpu_device_ = d;
+ }
+
+ // Return the Allocator implementation to use based on the allocator
+ // attributes requested. See allocator.h for more details.
+ virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) {
+ LOG(FATAL) << "GetAllocator() is not implemented.";
+ }
+
+ const Eigen::ThreadPoolDevice* eigen_cpu_device() {
+ CHECK(eigen_cpu_device_ != nullptr);
+ return eigen_cpu_device_;
+ }
+
+ // The caller owns the returned device and must free it by calling
+ // DisposeGpuDevice below
+ virtual const PerOpGpuDevice* MakeGpuDevice(DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {
+ // The OpKernelContext calls this even for devices that do not
+ // implement an eigen_gpu_device
+ return nullptr;
+ }
+
+ virtual const DeviceAttributes& attributes() const {
+ LOG(FATAL) << "Device does not implement attributes()";
+ }
+
+ // Materializes the given TensorProto into 'tensor' stored in Device
+ // memory. Most devices will want to override this.
+ //
+ // TODO(vrv): We should be able to put this function into
+ // OpKernelContext and handle the copies from device memory via send
+ // and receive nodes, instead of requiring that each device handle
+ // the copies here as well as in copy ops.
+ virtual Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ return errors::Internal("Device does not implement MakeTensorFromProto()");
+ }
+
+ private:
+ Env* const env_;
+ CpuWorkerThreads* cpu_worker_threads_ = nullptr;
+ GpuDeviceInfo* gpu_device_info_ = nullptr;
+ Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_DEVICE_BASE_H_
diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc
new file mode 100644
index 0000000000..493c35e05f
--- /dev/null
+++ b/tensorflow/core/framework/fake_input.cc
@@ -0,0 +1,214 @@
+#include "tensorflow/core/framework/fake_input.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace {
+
+class FakeInputImpl {
+ public:
+ FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
+ NodeDefBuilder* builder);
+ void SetN(int n);
+ void SetDataType(DataType dt);
+ void SetTypeList(DataTypeSlice dts);
+ Status AddInputToBuilder();
+
+ private:
+ static string FakeNodeName(int in_index);
+ Status GetN(int* n) const;
+ Status GetDataType(DataType* dt) const;
+ void NSources(int n, DataType dt) const;
+ void SourceList(DataTypeSlice dts) const;
+
+ const OpDef* const op_def_;
+ const OpDef::ArgDef* const arg_;
+ const string in_node_;
+ const NodeDef* const node_def_;
+ NodeDefBuilder* const builder_;
+
+ bool n_specified_;
+ int n_;
+ bool dt_specified_;
+ DataType dt_;
+ bool dts_specified_;
+ DataTypeSlice dts_;
+};
+
+FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
+ const NodeDef* node_def, NodeDefBuilder* builder)
+ : op_def_(op_def),
+ arg_(&op_def->input_arg(in_index)),
+ in_node_(FakeNodeName(in_index)),
+ node_def_(node_def),
+ builder_(builder),
+ n_specified_(false),
+ dt_specified_(false),
+ dts_specified_(false) {}
+
+void FakeInputImpl::SetN(int n) {
+ n_specified_ = true;
+ n_ = n;
+}
+
+void FakeInputImpl::SetDataType(DataType dt) {
+ dt_specified_ = true;
+ dt_ = dt;
+}
+
+void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
+ dts_specified_ = true;
+ dts_ = dts;
+}
+
+Status FakeInputImpl::AddInputToBuilder() {
+ if (dts_specified_) {
+ SourceList(dts_);
+
+ } else if (n_specified_ || !arg_->number_attr().empty()) {
+ int n;
+ TF_RETURN_IF_ERROR(GetN(&n));
+
+ DataType dt;
+ if (n > 0) {
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ } else {
+ dt = DT_FLOAT;
+ }
+
+ NSources(n, dt);
+ } else {
+ if (!dt_specified_ && !arg_->type_list_attr().empty()) {
+ DataTypeVector dts;
+ Status status =
+ GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
+ if (!status.ok()) {
+ return errors::InvalidArgument(
+ "Could not infer list of types for input '", arg_->name(), "': ",
+ status.error_message());
+ }
+ SourceList(dts);
+ return Status::OK();
+ }
+
+ DataType dt;
+ TF_RETURN_IF_ERROR(GetDataType(&dt));
+ builder_->Input(in_node_, 0, dt);
+ }
+ return Status::OK();
+}
+
+// static
+string FakeInputImpl::FakeNodeName(int in_index) {
+ char c = 'a' + (in_index % 26);
+ return string(&c, 1);
+}
+
+Status FakeInputImpl::GetN(int* n) const {
+ if (n_specified_) {
+ *n = n_;
+ } else {
+ Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer length of input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ }
+ return Status::OK();
+}
+
+Status FakeInputImpl::GetDataType(DataType* dt) const {
+ if (dt_specified_) {
+ *dt = dt_;
+ } else if (arg_->type() != DT_INVALID) {
+ *dt = arg_->type();
+ } else if (!arg_->type_attr().empty()) {
+ Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
+ if (!status.ok()) {
+ return errors::InvalidArgument("Could not infer type for input '",
+ arg_->name(), "': ",
+ status.error_message());
+ }
+ } else {
+ return errors::InvalidArgument("No type or type_attr field in arg '",
+ arg_->name(), "'");
+ }
+ return Status::OK();
+}
+
+void FakeInputImpl::NSources(int n, DataType dt) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ srcs.emplace_back(in_node_, i, dt);
+ }
+ builder_->Input(srcs);
+}
+
+void FakeInputImpl::SourceList(DataTypeSlice dts) const {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(dts.size());
+ for (size_t i = 0; i < dts.size(); ++i) {
+ srcs.emplace_back(in_node_, i, dts[i]);
+ }
+ builder_->Input(srcs);
+}
+
+} // namespace
+
+// Public interface ------------------------------------------------------------
+
+FakeInputFunctor FakeInput() {
+ return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(DataType dt) {
+ return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(int n) {
+ return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(int n, DataType dt) {
+ return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetN(n);
+ impl.SetDataType(dt);
+ return impl.AddInputToBuilder();
+ };
+}
+
+FakeInputFunctor FakeInput(DataTypeSlice dts) {
+ // Make a copy to ensure the data will still be around when the lambda is
+ // called.
+ DataTypeVector dtv(dts.begin(), dts.end());
+ return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
+ NodeDefBuilder* builder) {
+ FakeInputImpl impl(&op_def, in_index, &node_def, builder);
+ impl.SetTypeList(dtv);
+ return impl.AddInputToBuilder();
+ };
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/fake_input.h b/tensorflow/core/framework/fake_input.h
new file mode 100644
index 0000000000..39b38e9a59
--- /dev/null
+++ b/tensorflow/core/framework/fake_input.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+// These functions return values that may be passed to
+// NodeDefBuilder::Input() to add an input for a test. Use them when
+// you don't care about the node names/output indices providing the
+// input. They also allow you to omit the input types and/or
+// list length when they may be inferred.
+FakeInputFunctor FakeInput(); // Infer everything
+FakeInputFunctor FakeInput(DataType dt);
+FakeInputFunctor FakeInput(int n); // List of length n
+FakeInputFunctor FakeInput(int n, DataType dt);
+FakeInputFunctor FakeInput(DataTypeSlice dts);
+inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) {
+ return FakeInput(DataTypeSlice(dts));
+}
+
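+// Example (hypothetical test sketch; "SomeBinaryOp" is a placeholder op
+// name with two inputs of type T):
+//
+//   NodeDef node_def;
+//   TF_CHECK_OK(NodeDefBuilder("n", "SomeBinaryOp")
+//                   .Input(FakeInput(DT_FLOAT))
+//                   .Input(FakeInput(DT_FLOAT))
+//                   .Finalize(&node_def));
+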
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
new file mode 100644
index 0000000000..b73e1ab8a9
--- /dev/null
+++ b/tensorflow/core/framework/function.cc
@@ -0,0 +1,878 @@
+#include "tensorflow/core/framework/function.h"
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_Arg")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("index: int >= 0")
+ .Doc(R"doc(
+A graph node which represents an argument to a function.
+
+output: The argument.
+index: This argument is the index-th argument of the function.
+)doc");
+
+REGISTER_OP("_Retval")
+ .Input("input: T")
+ .Attr("T: type")
+ .Attr("index: int >= 0")
+ .Doc(R"doc(
+A graph node which represents a return value of a function.
+
+input: The return value.
+index: This return value is the index-th return value of the function.
+)doc");
+
+REGISTER_OP("_ListToArray")
+ .Input("input: Tin")
+ .Output("output: N * T")
+ .Attr("Tin: list(type)")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Converts a list of tensors to an array of tensors.
+)doc");
+
+REGISTER_OP("_ArrayToList")
+ .Input("input: N * T")
+ .Output("output: out_types")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Attr("out_types: list(type)")
+ .Doc(R"doc(
+Converts an array of tensors to a list of tensors.
+)doc");
+
+namespace {
+
+// Extracts the actual type from "attr_values" based on its definition
+// "arg_def".
+Status ArgNumType(const InstantiateAttrValueMap& attrs,
+ const OpDef::ArgDef& arg_def, int* num, DataType* dtype) {
+ if (!arg_def.type_list_attr().empty()) {
+ return errors::Unimplemented("type_list is not supported.");
+ }
+
+ if (arg_def.number_attr().empty()) {
+ *num = 1;
+ } else {
+ const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr());
+ if (v == nullptr) {
+      return errors::NotFound("number attr not found: ",
+                              arg_def.number_attr());
+ }
+ *num = v->i();
+ }
+
+ if (arg_def.type() != DT_INVALID) {
+ *dtype = arg_def.type();
+ } else if (arg_def.type_attr().empty()) {
+ *dtype = DT_INVALID;
+ } else {
+ const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr());
+ if (v == nullptr) {
+ return errors::NotFound("type attr not found: ", arg_def.type_attr());
+ }
+ *dtype = v->type();
+ }
+ return Status::OK();
+}
+
+string Name(int node_index) { return strings::StrCat("n", node_index); }
+
+string Name(int node_index, int output_index) {
+ if (output_index == 0) {
+ return Name(node_index);
+ } else {
+ return strings::StrCat("n", node_index, ":", output_index);
+ }
+}
+
+string Dep(int node_index) { return strings::StrCat("^", Name(node_index)); }
+
+template <typename T>
+void AddAttr(const string& name, const T& val, NodeDef* ndef) {
+ SetAttrValue(val, &((*ndef->mutable_attr())[name]));
+}
+
+Status ValidateSignatureWithAttrs(const OpDef& sig,
+ const InstantiateAttrValueMap& attr_values) {
+ // attr_values should specify all attrs defined in fdef.
+ for (const auto& a : sig.attr()) {
+ if (attr_values.find(a.name()) == attr_values.end()) {
+ return errors::NotFound("Attr ", a.name(), " is not found.");
+ }
+ }
+
+ for (const auto& p : attr_values) {
+ if (HasPlaceHolder(p.second)) {
+ return errors::InvalidArgument(p.first,
+ " in attr_values is still a placeholder.");
+ }
+ }
+
+ return Status::OK();
+}
+
+// We build a small index for all names that can be used as a node's
+// input arguments.
+//
+// If is_func_arg is true, the name is a function's argument. In
+// this case, the produced graph def has nodes gdef.node[nid ... nid +
+// num).
+//
+// Otherwise, the name is a function body's node return value. In
+// this case, the produced graph def has one node gdef.node[nid] and
+// the node's output index [idx ... idx + num) corresponds to the
+// named outputs.
+//
+// In all cases, "dtype" specifies the data type.
+struct NameInfoItem {
+ bool is_func_arg;
+ int nid;
+ int idx;
+ int num;
+ DataType dtype;
+};
+typedef std::unordered_map<string, NameInfoItem> NameInfoIndex;
+
+Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
+ const InstantiateAttrValueMap& attr_values,
+ NameInfoIndex* name_info,
+ InstantiationResult* result) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(ArgNumType(attr_values, arg_def, &num, &dtype));
+ CHECK_GE(num, 1);
+ GraphDef* gdef = &result->gdef;
+ int arg_index = gdef->node_size();
+ if (!name_info->insert({arg_def.name(), {true, arg_index, 0, num, dtype}})
+ .second) {
+ return errors::InvalidArgument("Duplicated arg name.");
+ }
+ // Creates "num" nodes in the gdef.
+ for (int i = 0; i < num; ++i) {
+ DCHECK_EQ(arg_index, gdef->node_size());
+ NodeDef* gnode = gdef->add_node();
+ gnode->set_name(Name(arg_index));
+ gnode->set_op("_Arg");
+ AddAttr("T", dtype, gnode);
+ AddAttr("index", arg_index, gnode);
+ result->arg_types.push_back(dtype);
+ ++arg_index;
+ }
+ return Status::OK();
+}
+
+Status BuildNodeOutputIndex(const FunctionDef::Node& node,
+ const InstantiateAttrValueMap& attrs,
+ GetFunctionSignature get_function,
+ const int arg_index, NameInfoIndex* name_info) {
+ const OpDef* node_sig = nullptr;
+ TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig));
+ if (node_sig->output_arg_size() == 0) {
+ // This node produces no output.
+ if (node.ret_size() != 1) {
+ return errors::InvalidArgument("Expect one ret name.");
+ }
+ if (!name_info->insert({node.ret(0), {false, arg_index, 0, 0, DT_INVALID}})
+ .second) {
+ return errors::InvalidArgument("Duplicated ret name.");
+ }
+ return Status::OK();
+ }
+
+ // When the signature says the last return value is of list(type),
+ // i.e., it's variadic, we need to consult
+ // attrs[last_retval.type_list_attr] to determine for the last arg
+ // * the actual number of outputs;
+ // * the actual data type of outputs.
+ const int num_retval = node_sig->output_arg_size();
+ const OpDef::ArgDef& last_retval = node_sig->output_arg(num_retval - 1);
+ const bool last_retval_is_typelist = !last_retval.type_list_attr().empty();
+ if (!last_retval_is_typelist && (node.ret_size() != num_retval)) {
+ return errors::InvalidArgument("Malformed function node (#ret).");
+ }
+ int start = 0;
+ const int num_fixed_size_retval =
+ last_retval_is_typelist ? num_retval - 1 : num_retval;
+ for (int i = 0; i < num_fixed_size_retval; ++i) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(
+ ArgNumType(attrs, node_sig->output_arg(i), &num, &dtype));
+ if (!name_info->insert({node.ret(i), {false, arg_index, start, num, dtype}})
+ .second) {
+ return errors::InvalidArgument("Duplicated ret name.");
+ }
+ start += num;
+ }
+ if (last_retval_is_typelist) {
+ const AttrValue* typelist =
+ gtl::FindOrNull(attrs, last_retval.type_list_attr());
+ if (typelist == nullptr) {
+ return errors::InvalidArgument("Missing attr ",
+ last_retval.type_list_attr(), ".");
+ }
+ if (num_fixed_size_retval + typelist->list().type_size() !=
+ node.ret_size()) {
+ return errors::InvalidArgument("Wrong #ret: ", num_fixed_size_retval, " ",
+ typelist->list().type_size(), " ",
+ node.ret_size(), ".");
+ }
+ for (int i = 0; i < typelist->list().type_size(); ++i) {
+ if (!name_info->insert({node.ret(i),
+ {false, arg_index, start, 1,
+ typelist->list().type(i)}})
+ .second) {
+ return errors::InvalidArgument("Duplicated ret name.");
+ }
+ ++start;
+ }
+ }
+ return Status::OK();
+}
+
+Status InstantiateNode(const FunctionDef::Node& fnode,
+ const InstantiateAttrValueMap& attrs,
+ GetFunctionSignature get_function,
+ const NameInfoIndex& name_info, GraphDef* gdef) {
+ const OpDef* fnode_sig = nullptr;
+ TF_CHECK_OK(get_function(fnode.op(), &fnode_sig));
+ NodeDef* gnode = gdef->add_node();
+ gnode->set_name(Name(gdef->node_size() - 1));
+ gnode->set_op(fnode.op());
+
+ // Input
+ //
+ // When the signature says the last argument is of list(type),
+ // i.e., it's variadic, we need to consult
+ // attrs[last_arg.type_list_attr] to determine for the last arg
+ // * the number of arguments;
+ // * the data types of arguments.
+ const int num_arg = fnode_sig->input_arg_size();
+ bool last_arg_is_typelist = false;
+ if (num_arg > 0 &&
+ !fnode_sig->input_arg(num_arg - 1).type_list_attr().empty()) {
+ last_arg_is_typelist = true;
+ }
+ if (!last_arg_is_typelist && (fnode.arg_size() != num_arg)) {
+ return errors::InvalidArgument("arg.size != sig.arg.size.");
+ }
+ const int num_fixed_size_args = last_arg_is_typelist ? num_arg - 1 : num_arg;
+ for (int i = 0; i < num_fixed_size_args; ++i) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(
+ ArgNumType(attrs, fnode_sig->input_arg(i), &num, &dtype));
+ const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i));
+ if (item == nullptr) {
+ return errors::InvalidArgument("arg[", i, "] is not found: ",
+ fnode.ShortDebugString());
+ }
+ if (num != item->num || dtype != item->dtype) {
+ return errors::InvalidArgument("Invalid arg(", i, ") for function arg: ",
+ " ", num, "/", dtype, " vs. ", item->num,
+ "/", item->dtype, ".");
+ }
+ for (int j = 0; j < num; ++j) {
+ if (item->is_func_arg) {
+ gnode->add_input(Name(item->nid + j));
+ } else {
+ gnode->add_input(Name(item->nid, item->idx + j));
+ }
+ }
+ }
+ if (last_arg_is_typelist) {
+ AttrValue typelist;
+ for (int i = num_fixed_size_args; i < fnode.arg_size(); ++i) {
+ const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i));
+ if (item == nullptr) {
+ return errors::InvalidArgument("arg[", i, "] is not found.");
+ }
+ for (int j = 0; j < item->num; ++j) {
+ if (item->is_func_arg) {
+ gnode->add_input(Name(item->nid + j));
+ } else {
+ gnode->add_input(Name(item->nid, item->idx + j));
+ }
+ typelist.mutable_list()->add_type(item->dtype);
+ }
+ }
+
+ // 'typelist' is inferred from the inputs' data types.
+ const auto& last_arg = fnode_sig->input_arg(num_arg - 1);
+ gnode->mutable_attr()->insert({last_arg.type_list_attr(), typelist});
+ }
+
+ // Control deps.
+ for (int i = 0; i < fnode.dep_size(); ++i) {
+ const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i));
+ if (item == nullptr) {
+ return errors::InvalidArgument("dep[", i, "] is not found.");
+ }
+ gnode->add_input(Dep(item->nid));
+ }
+
+ // Attrs.
+ for (const auto& p : attrs) {
+ (*gnode->mutable_attr())[p.first] = p.second;
+ }
+
+ return Status::OK();
+}
+
+Status AddReturnNode(const OpDef::ArgDef& ret_def,
+ const InstantiateAttrValueMap& attrs,
+ const NameInfoIndex& name_info, int* ret_index,
+ InstantiationResult* result) {
+ int num;
+ DataType dtype;
+ TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &num, &dtype));
+ CHECK_GE(num, 1);
+ const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name());
+ if (item == nullptr) {
+ return errors::InvalidArgument("ret is not found.");
+ }
+ if (num != item->num || dtype != item->dtype) {
+ return errors::InvalidArgument("Invalid ret name.");
+ }
+ GraphDef* gdef = &result->gdef;
+ for (int i = 0; i < num; ++i) {
+ NodeDef* gnode = gdef->add_node();
+ gnode->set_name(Name(gdef->node_size() - 1));
+ gnode->set_op("_Retval");
+ gnode->add_input(Name(item->nid, item->idx + i));
+ AddAttr("T", dtype, gnode);
+ AddAttr("index", (*ret_index)++, gnode);
+ result->ret_types.push_back(dtype);
+ }
+ return Status::OK();
+}
+
+// Various Print(proto) helpers to print relevant protos to ASCII.
+string Print(const OpDef::ArgDef& arg) {
+ string out;
+ strings::StrAppend(&out, arg.name(), ":");
+ if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
+ if (!arg.number_attr().empty()) {
+ strings::StrAppend(&out, arg.number_attr(), "*");
+ }
+ if (arg.type() != DT_INVALID) {
+ strings::StrAppend(&out, DataTypeString(arg.type()));
+ } else {
+ strings::StrAppend(&out, arg.type_attr());
+ }
+ if (arg.is_ref()) strings::StrAppend(&out, ")");
+ return out;
+}
+
+string Print(const AttrValue& attr_value) {
+ if (attr_value.value_case() == AttrValue::kType) {
+ return DataTypeString(attr_value.type());
+ } else if ((attr_value.value_case() == AttrValue::kList) &&
+ (attr_value.list().type_size() > 0)) {
+ string ret = "{";
+ for (int i = 0; i < attr_value.list().type_size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
+ }
+ strings::StrAppend(&ret, "}");
+ return ret;
+ } else if (attr_value.value_case() == AttrValue::kFunc) {
+ if (attr_value.func().attr_size() == 0) {
+ return attr_value.func().name();
+ }
+ std::vector<string> entries;
+ for (auto p : attr_value.func().attr()) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ return strings::StrCat(attr_value.func().name(), "[",
+ str_util::Join(entries, ", "), "]");
+ }
+ return SummarizeAttrValue(attr_value);
+}
+
+string Print(const FunctionDef::Node& node) {
+ string out;
+ for (int i = 0; i < node.ret_size(); ++i) {
+ const auto& name = node.ret(i);
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, name);
+ }
+ strings::StrAppend(&out, " = ", node.op());
+ if (node.attr_size() > 0) {
+ std::vector<string> entries;
+ for (auto p : node.attr()) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
+ }
+ strings::StrAppend(&out, "(");
+ for (int i = 0; i < node.arg_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, node.arg(i));
+ }
+ strings::StrAppend(&out, ")");
+ if (node.dep_size() > 0) {
+ strings::StrAppend(&out, " @ ");
+ for (int i = 0; i < node.dep_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, node.dep(i));
+ }
+ }
+ return out;
+}
+
+string Print(const FunctionDef& fdef) {
+ string out;
+ const OpDef& sig = fdef.signature();
+ strings::StrAppend(&out, "\n", sig.name());
+ if (sig.attr_size() > 0) {
+ strings::StrAppend(&out, "[");
+ for (int i = 0; i < sig.attr_size(); ++i) {
+ const auto& a = sig.attr(i);
+ if (i > 0) strings::StrAppend(&out, ", ");
+ if (a.type() == "type") {
+ strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
+ } else {
+ strings::StrAppend(&out, a.name(), ":", a.type());
+ }
+ }
+ strings::StrAppend(&out, "]");
+ }
+ strings::StrAppend(&out, "(");
+ for (int i = 0; i < sig.input_arg_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, Print(sig.input_arg(i)));
+ }
+ strings::StrAppend(&out, ") -> (");
+ for (int i = 0; i < sig.output_arg_size(); ++i) {
+ if (i > 0) strings::StrAppend(&out, ", ");
+ strings::StrAppend(&out, Print(sig.output_arg(i)));
+ }
+ strings::StrAppend(&out, ") {\n");
+ for (const auto& n : fdef.node()) {
+ strings::StrAppend(&out, " ", Print(n), "\n");
+ }
+ strings::StrAppend(&out, "}\n");
+ return out;
+}
+
+string Print(const NodeDef& n) {
+ string out;
+ strings::StrAppend(&out, n.name(), " = ", n.op());
+ if (n.attr_size() > 0) {
+ std::vector<string> entries;
+ for (auto& a : n.attr()) {
+ entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
+ }
+ sort(entries.begin(), entries.end());
+ strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
+ }
+ strings::StrAppend(&out, "(");
+ std::vector<StringPiece> dat;
+ std::vector<string> dep;
+ for (StringPiece s : n.input()) {
+ if (s.Consume("^")) {
+ dep.push_back(s.ToString());
+ } else {
+ dat.push_back(s);
+ }
+ }
+ strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
+ if (!dep.empty()) {
+ strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
+ }
+ return out;
+}
+
+string Print(const GraphDef& gdef) {
+ std::vector<const NodeDef*> arg;
+ std::vector<const NodeDef*> ret;
+ std::vector<const NodeDef*> body;
+ for (const NodeDef& n : gdef.node()) {
+ if (n.op() == "_Arg") {
+ arg.push_back(&n);
+ } else if (n.op() == "_Retval") {
+ ret.push_back(&n);
+ } else {
+ body.push_back(&n);
+ }
+ }
+ auto comp = [](const NodeDef* x, const NodeDef* y) {
+ int xi;
+ TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
+ int yi;
+ TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
+ return xi < yi;
+ };
+ sort(arg.begin(), arg.end(), comp);
+ sort(ret.begin(), ret.end(), comp);
+ string out;
+ strings::StrAppend(&out, "\n(");
+ auto get_type = [](const NodeDef& n) {
+ for (auto a : n.attr()) {
+ if (a.first == "T") {
+ return DataTypeString(a.second.type());
+ }
+ }
+ return DataTypeString(DT_INVALID);
+ };
+ for (size_t i = 0; i < arg.size(); ++i) {
+ const NodeDef* n = arg[i];
+ if (i > 0) strings::StrAppend(&out, ", ");
+ CHECK_EQ(2, n->attr_size());
+ strings::StrAppend(&out, n->name(), ":", get_type(*n));
+ }
+ strings::StrAppend(&out, ") -> (");
+ for (size_t i = 0; i < ret.size(); ++i) {
+ const NodeDef* n = ret[i];
+ if (i > 0) strings::StrAppend(&out, ", ");
+ CHECK_EQ(2, n->attr_size());
+ CHECK_EQ(1, n->input_size());
+ strings::StrAppend(&out, n->input(0), ":", get_type(*n));
+ }
+ strings::StrAppend(&out, ") {\n");
+ for (size_t i = 0; i < body.size(); ++i) {
+ strings::StrAppend(&out, " ", Print(*body[i]), "\n");
+ }
+ strings::StrAppend(&out, "}\n");
+ return out;
+}
+
+} // end namespace
+
+Status InstantiateFunction(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result) {
+ const OpDef& sig = fdef.signature();
+ GraphDef* gdef = &result->gdef;
+ gdef->Clear();
+
+ TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
+
+ auto substitute = [&attr_values](const string& name, AttrValue* val) {
+ auto iter = attr_values.find(name);
+ if (iter == attr_values.end()) {
+ return false;
+ } else {
+ *val = iter->second;
+ return true;
+ }
+ };
+
+ // Makes a copy of all attrs in fdef and substitutes placeholders.
+ // After this step, every attr is bound to a concrete value.
+ std::vector<InstantiateAttrValueMap> node_attrs;
+ node_attrs.resize(fdef.node_size());
+ for (int i = 0; i < fdef.node_size(); ++i) {
+ for (auto attr : fdef.node(i).attr()) {
+ if (!SubstitutePlaceholders(substitute, &attr.second)) {
+ return errors::InvalidArgument("Failed to bind all placeholders in ",
+ SummarizeAttrValue(attr.second));
+ }
+ CHECK(node_attrs[i].insert(attr).second);
+ }
+ }
+
+ NameInfoIndex name_info;
+ Status s;
+ for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
+ s = BuildInputArgIndex(arg_def, attr_values, &name_info, result);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(arg_def));
+ return s;
+ }
+ }
+ for (int i = 0; i < fdef.node_size(); ++i) {
+ s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function,
+ gdef->node_size() + i, &name_info);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(fdef.node(i)));
+ return s;
+ }
+ }
+
+ // Emits one gdef.node for each fdef.node.
+ for (int i = 0; i < fdef.node_size(); ++i) {
+ s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info,
+ gdef);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(fdef.node(i)));
+ return s;
+ }
+ }
+
+ // Emits nodes for the function's return values.
+ int ret_index = 0;
+ for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
+ s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " In ", Print(ret_def));
+ return s;
+ }
+ }
+
+ return Status::OK();
+}
+
+string DebugString(const FunctionDef& func_def) { return Print(func_def); }
+
+string DebugString(const GraphDef& instantiated_func_def) {
+ return Print(instantiated_func_def);
+}
+
+string DebugStringWhole(const GraphDef& gdef) {
+ string ret;
+ for (auto fdef : gdef.library().function()) {
+ strings::StrAppend(&ret, Print(fdef));
+ }
+ strings::StrAppend(&ret, "\n");
+ for (auto ndef : gdef.node()) {
+ strings::StrAppend(&ret, Print(ndef), "\n");
+ }
+ return ret;
+}
+
+string Canonicalize(const string& funcname,
+ const InstantiateAttrValueMap& attrs) {
+ std::vector<string> entries;
+ entries.reserve(attrs.size());
+ for (auto p : attrs) {
+ entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
+ }
+ sort(entries.begin(), entries.end());
+ return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
+}
+
+FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
+ DataTypeSlice ret_types)
+ : arg_types_(arg_types.begin(), arg_types.end()),
+ ret_types_(ret_types.begin(), ret_types.end()) {
+ args_.resize(arg_types_.size());
+ rets_.resize(ret_types_.size());
+}
+
+FunctionCallFrame::~FunctionCallFrame() {}
+
+Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
+ // Input type checks.
+ if (args.size() != arg_types_.size()) {
+ return errors::InvalidArgument("Expects ", arg_types_.size(),
+ " arguments, but ", args.size(),
+ " is provided");
+ }
+ for (size_t i = 0; i < args.size(); ++i) {
+ if (arg_types_[i] != args[i].dtype()) {
+ return errors::InvalidArgument(
+ "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
+ DataTypeString(args[i].dtype()), " is provided");
+ }
+ args_[i] = args[i];
+ }
+ return Status::OK();
+}
+
+Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
+ rets->clear();
+ rets->reserve(rets_.size());
+ for (size_t i = 0; i < rets_.size(); ++i) {
+ auto item = rets_[i];
+ if (item.has_val) {
+ rets->push_back(item.val);
+ } else {
+ return errors::Internal("Retval[", i, "] does not have value");
+ }
+ }
+ return Status::OK();
+}
+
+Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
+ if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
+ return errors::OutOfRange("GetArg ", index, " is not within [0, ",
+ args_.size(), ")");
+ }
+ *val = args_[index];
+ return Status::OK();
+}
+
+Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
+ if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
+ return errors::OutOfRange("SetRetval ", index, " is not within [0, ",
+ rets_.size(), ")");
+ }
+ if (val.dtype() != ret_types_[index]) {
+ return errors::InvalidArgument(
+ "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
+ ", but ", DataTypeString(val.dtype()), " is provided.");
+ }
+ Retval* item = &rets_[index];
+ if (!item->has_val) {
+ item->has_val = true;
+ item->val = val;
+ } else {
+ return errors::Internal("Retval[", index, "] has already been set.");
+ }
+ return Status::OK();
+}
+
+FunctionLibraryDefinition::FunctionLibraryDefinition(
+ const FunctionDefLibrary& def_lib)
+ : function_defs_(def_lib.function_size()) {
+ for (auto fdef : def_lib.function()) {
+ // The latter function definition wins.
+ function_defs_[fdef.signature().name()] = fdef;
+ }
+}
+
+FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
+
+const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
+ auto iter = function_defs_.find(name);
+ if (iter == function_defs_.end()) {
+ return nullptr;
+ } else {
+ return &iter->second;
+ }
+}
+
+const OpDef* FunctionLibraryDefinition::LookUp(const string& op,
+ Status* status) const {
+ auto fdef = Find(op);
+ if (fdef != nullptr) {
+ return &(fdef->signature());
+ }
+ return OpRegistry::Global()->LookUp(op, status);
+}
+
+Status InstantiateFunction(const FunctionDef& fdef,
+ InstantiateAttrValueSlice attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result) {
+ InstantiateAttrValueMap m;
+ for (const auto& aval : attr_values) {
+ m.insert({aval.first, aval.second.proto});
+ }
+ return InstantiateFunction(fdef, m, get_function, result);
+}
+
+string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) {
+ InstantiateAttrValueMap m;
+ for (const auto& aval : attrs) {
+ m.insert({aval.first, aval.second.proto});
+ }
+ return Canonicalize(funcname, m);
+}
+
+Status FunctionLibraryRuntime::Instantiate(const string& function_name,
+ InstantiateAttrValueSlice attrs,
+ Handle* handle) {
+ InstantiateAttrValueMap m;
+ for (const auto& aval : attrs) {
+ m.insert({aval.first, aval.second.proto});
+ }
+ return Instantiate(function_name, m, handle);
+}
+
+void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
+ if (val.size() >= 2 && val[0] == '$') {
+ proto.set_placeholder(val.data() + 1, val.size() - 1);
+ } else {
+ SetAttrValue(val, &proto);
+ }
+}
+
+FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
+ const string& name,
+ gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
+ AttrValueWrapper ret;
+ ret.proto.mutable_func()->set_name(name);
+ for (const auto& a : attrs) {
+ ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
+ }
+ return ret;
+}
+
+FunctionDef::Node FunctionDefHelper::Node::ToProto() const {
+ FunctionDef::Node n;
+ for (const string& r : this->ret) {
+ n.add_ret(r);
+ }
+ n.set_op(this->op);
+ for (const string& a : arg) {
+ n.add_arg(a);
+ }
+ for (const auto& a : this->attr) {
+ n.mutable_attr()->insert({a.first, a.second.proto});
+ }
+ for (const string& d : dep) {
+ n.add_dep(d);
+ }
+ return n;
+}
+
+/* static */
+FunctionDef FunctionDefHelper::Define(const string& name,
+ gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def) {
+ FunctionDef fdef;
+ OpDefBuilder b(name);
+ for (const auto& a : arg_def) b.Input(a);
+ for (const auto& r : ret_def) b.Output(r);
+ for (const auto& a : attr_def) b.Attr(a);
+ TF_CHECK_OK(b.Finalize(fdef.mutable_signature()));
+ for (const auto& n : node_def) {
+ *(fdef.add_node()) = n.ToProto();
+ }
+ return fdef;
+}
+
+FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def) {
+ return Define("_", arg_def, ret_def, attr_def, node_def);
+}
+
+namespace gradient {
+
+typedef std::unordered_map<string, Creator> OpGradFactory;
+
+OpGradFactory* GetOpGradFactory() {
+ static OpGradFactory* factory = new OpGradFactory;
+ return factory;
+}
+
+bool RegisterOp(const string& op, Creator func) {
+ CHECK(GetOpGradFactory()->insert({op, func}).second)
+ << "Duplicated gradient for " << op;
+ return true;
+}
+
+Status GetOpGradientCreator(const string& op, Creator* creator) {
+ auto fac = GetOpGradFactory();
+ auto iter = fac->find(op);
+ if (iter == fac->end()) {
+ return errors::NotFound("No gradient defined for op: ", op);
+ }
+ *creator = iter->second;
+ return Status::OK();
+}
+
+} // end namespace gradient
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
new file mode 100644
index 0000000000..1ef93a0533
--- /dev/null
+++ b/tensorflow/core/framework/function.h
@@ -0,0 +1,376 @@
+#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
+#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+class CancellationManager;
+class Node;
+class OpKernel;
+
+// FunctionDefHelper::Define is a convenient helper to construct a
+// FunctionDef proto.
+//
+// E.g.,
+// FunctionDef my_func = FunctionDefHelper::Define(
+// "my_func_name",
+// {"x:T", "y:T" /* one string per argument */},
+// {"z:T" /* one string per return value */},
+// {"T: {float, double}" /* one string per attribute */},
+// {
+// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}
+// /* one entry per function node */
+// })
+//
+// NOTE: When we have a TFLang parser, we can add another helper:
+// FunctionDef FunctionDefHelper::Define(const string& tf_func);
+class FunctionDefHelper {
+ public:
+ // AttrValueWrapper has copy constructors for the type T so that
+ // it's easy to construct a simple AttrValue proto.
+ //
+ // If T is a string type (const char*, string, or StringPiece), and
+  // it starts with "$", we construct an AttrValue of "placeholder".
+  //
+  // E.g.,
+  //   std::pair<string, AttrValueWrapper> x = {"T", "$T"}
+ // is a named attr value placeholder.
+ struct AttrValueWrapper {
+ AttrValue proto;
+
+ AttrValueWrapper() {}
+
+ template <typename T>
+ AttrValueWrapper(T val) { // NOLINT(runtime/explicit)
+ SetAttrValue(val, &proto);
+ }
+
+ private:
+ void InitFromString(StringPiece val);
+ };
+
+ // Constructs an AttrValue.func given the "name" and "attrs".
+ static AttrValueWrapper FunctionRef(
+ const string& name,
+ gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs);
+ static AttrValueWrapper FunctionRef(const string& name) {
+ return FunctionRef(name, {});
+ }
+
+  // Node is used to construct FunctionDef.Node using initialization
+ // lists. E.g.,
+ // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y
+ struct Node {
+ std::vector<string> ret;
+ string op;
+ std::vector<string> arg;
+ std::vector<std::pair<string, AttrValueWrapper>> attr;
+ std::vector<string> dep;
+
+ FunctionDef::Node ToProto() const;
+ };
+
+ static FunctionDef Define(const string& function_name,
+ gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def);
+
+ // Defines an anonymous function. I.e., its name is not relevant.
+ static FunctionDef Define(gtl::ArraySlice<string> arg_def,
+ gtl::ArraySlice<string> ret_def,
+ gtl::ArraySlice<string> attr_def,
+ gtl::ArraySlice<Node> node_def);
+
+ // Helpers to construct a constant scalar.
+ template <typename T>
+ static Node Const(const string& name, const T& val) {
+ Node n = {{name}, "Const"};
+ const DataType dtype = DataTypeToEnum<T>::value;
+ n.attr.push_back({"dtype", dtype});
+ Tensor t(dtype, TensorShape({}));
+ t.scalar<T>()() = val;
+ n.attr.push_back({"value", t});
+ return n;
+ }
+
+ template <typename T>
+ static Node Const(const string& name, gtl::ArraySlice<T> vals) {
+ Node n = {{name}, "Const"};
+ const DataType dtype = DataTypeToEnum<T>::value;
+ n.attr.push_back({"dtype", dtype});
+ int64 num = vals.size();
+ Tensor t(dtype, TensorShape({num}));
+ for (int i = 0; i < vals.size(); ++i) {
+ t.flat<T>()(i) = vals[i];
+ }
+ n.attr.push_back({"value", t});
+ return n;
+ }
+};
+
+template <>
+inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) {
+ InitFromString(val);
+}
+
+template <>
+inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(
+ const string& val) {
+ InitFromString(val);
+}
+
+template <>
+inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) {
+ InitFromString(val);
+}
+
+// Instantiate a function.
+//
+// "fdef" encodes a TF function with some attrs in fdef.signature.attr
+// containing placeholders. InstantiateFunction binds these
+// placeholders and produces an instantiated function encoded in
+// "result.gdef". The value to substitute a placeholder is given by
+// "attr_values", which is a map from a placeholder name to an attr
+// value.
+//
+// InstantiateFunction calls "get_function" to find signatures of other
+// functions and primitive ops.
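+//
+// For example (see the tests in function_test.cc), instantiating
+// SquarePlusOne with T=float produces a GraphDef containing _Arg, Square,
+// One, Add and _Retval nodes, all bound to T=float.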
+
+// Placeholders in "fdef" are substituted based on "attr_values" here.
+typedef ::tensorflow::protobuf::Map<string, AttrValue> InstantiateAttrValueMap;
+typedef gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
+ InstantiateAttrValueSlice;
+
+// GetFunctionSignature(func name, opdef) returns OK if the func name is found
+// and opdef is filled with a pointer to the corresponding signature
+// (a OpDef proto). Otherwise, returns an error.
+typedef std::function<Status(const string&, const OpDef**)>
+ GetFunctionSignature;
+
+struct InstantiationResult {
+ DataTypeVector arg_types;
+ DataTypeVector ret_types;
+ GraphDef gdef;
+};
+Status InstantiateFunction(const FunctionDef& fdef,
+ const InstantiateAttrValueMap& attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result);
+Status InstantiateFunction(const FunctionDef& fdef,
+ InstantiateAttrValueSlice attr_values,
+ GetFunctionSignature get_function,
+ InstantiationResult* result);
+
+// Returns a debug string for a function definition.
+//
+// The returned text is multi-line. It is intended to be
+// human-readable rather than being friendly to parsers. It is _NOT_
+// intended to be the canonical string representation of "func_def".
+// In particular, it may not include all information presented in
+// "func_def" (e.g., comments, description of the function arguments,
+// etc.)
+string DebugString(const FunctionDef& func_def);
+string DebugString(const GraphDef& instantiated_func_def);
+
+// Returns a debug string for a top level graph (the main program and
+// its supporting functions defined in its library).
+string DebugStringWhole(const GraphDef& gdef);
+
+// Returns a canonicalized string for the instantiation of the
+// function of the given "name" and attributes "attrs".
+//
+// The returned string is guaranteed to be stable within one address
+// space. But it may change as the implementation
+// evolves. Therefore, it should not be persisted or compared across
+// address spaces.
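+//
+// For example, with the current implementation, an attrs map of {"T": float}
+// canonicalizes "MatMul" to "MatMul[T=float]".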
+string Canonicalize(const string& funcname,
+ const InstantiateAttrValueMap& attrs);
+string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs);
+
+// Represents a function call frame. I.e., the data structure used to
+// pass arguments to a function and retrieve its results.
+//
+// Runtime must arrange accesses to one FunctionCallFrame s.t.
+// 1. SetArgs() happens before any GetArg();
+// 2. GetRetvals happens after all SetRetval();
+class FunctionCallFrame {
+ public:
+ FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types);
+ ~FunctionCallFrame();
+
+ // Caller methods.
+ Status SetArgs(gtl::ArraySlice<Tensor> args);
+ Status GetRetvals(std::vector<Tensor>* rets) const;
+
+ // Callee methods.
+ Status GetArg(int index, Tensor* val) const;
+ Status SetRetval(int index, const Tensor& val);
+
+ private:
+ DataTypeVector arg_types_;
+ DataTypeVector ret_types_;
+ gtl::InlinedVector<Tensor, 4> args_;
+ struct Retval {
+ bool has_val = false;
+ Tensor val;
+ };
+ gtl::InlinedVector<Retval, 4> rets_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionCallFrame);
+};
+
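+// Example (hypothetical caller-side sketch):
+//
+//   FunctionCallFrame frame({DT_FLOAT}, {DT_FLOAT});
+//   Tensor x(DT_FLOAT, TensorShape({}));
+//   x.scalar<float>()() = 2.0f;
+//   TF_CHECK_OK(frame.SetArgs({x}));
+//   ... the callee runs, using GetArg() and SetRetval() on the frame ...
+//   std::vector<Tensor> rets;
+//   TF_CHECK_OK(frame.GetRetvals(&rets));
+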
+// Helper to maintain a map between function names in a given
+// FunctionDefLibrary and function definitions.
+class FunctionLibraryDefinition : public OpRegistryInterface {
+ public:
+ explicit FunctionLibraryDefinition(const FunctionDefLibrary& lib_def);
+ ~FunctionLibraryDefinition() override;
+
+ // Returns nullptr if "func" is not defined in "lib_def". Otherwise,
+ // returns its definition proto.
+ const FunctionDef* Find(const string& func) const;
+
+ // OpRegistryInterface method. Useful for constructing a Graph.
+ //
+ // If "op" is defined in the library, returns its signature.
+ // Otherwise, assume "op" is a primitive op and returns its op
+ // signature.
+ const OpDef* LookUp(const string& op, Status* status) const override;
+
+ private:
+ std::unordered_map<string, FunctionDef> function_defs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryDefinition);
+};
+
+// Forward declare. Defined in common_runtime/function.h
+struct FunctionBody;
+
+class FunctionLibraryRuntime {
+ public:
+ virtual ~FunctionLibraryRuntime() {}
+
+ // Instantiate a function with the given "attrs".
+ //
+ // Returns OK and fills in "handle" if the instantiation succeeds.
+ // Otherwise returns an error and "handle" is undefined.
+ typedef uint64 Handle;
+ virtual Status Instantiate(const string& function_name,
+ const InstantiateAttrValueMap& attrs,
+ Handle* handle) = 0;
+ Status Instantiate(const string& function_name,
+ InstantiateAttrValueSlice attrs, Handle* handle);
+
+ // Returns the function body for the instantiated function given its
+ // handle 'h'. Returns nullptr if "h" is not found.
+ //
+ // *this keeps the ownership of the returned object, which remains alive
+ // as long as *this.
+ virtual const FunctionBody* GetFunctionBody(Handle h) = 0;
+
+ // Asynchronously invokes the instantiated function identified by
+ // "handle".
+ //
+ // If function execution succeeds, "done" is called with OK and
+  // "*rets" is filled with the function's return values. Otherwise,
+ // "done" is called with an error status.
+ //
+ // Does not take ownership of "rets".
+ struct Options {
+ CancellationManager* cancellation_manager = nullptr;
+ };
+ typedef std::function<void(const Status&)> DoneCallback;
+ virtual void Run(const Options& opts, Handle handle,
+ gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
+ DoneCallback done) = 0;
+
+ // Creates a "kernel" for the given node def "ndef".
+ //
+ // If succeeds, returns OK and the caller takes the ownership of the
+ // returned "*kernel". Otherwise, returns an error.
+ virtual Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) = 0;
+
+ // Return true iff 'function_name' is the name of a defined function.
+ virtual bool IsDefined(const string& function_name) = 0;
+};
+
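+// Example (hypothetical usage sketch; assumes "lib" is a concrete
+// FunctionLibraryRuntime and "args" holds the input tensors):
+//
+//   FunctionLibraryRuntime::Handle handle;
+//   TF_CHECK_OK(lib->Instantiate("SquarePlusOne", {{"T", DT_FLOAT}}, &handle));
+//   std::vector<Tensor> rets;
+//   Notification done;
+//   lib->Run(FunctionLibraryRuntime::Options(), handle, args, &rets,
+//            [&done](const Status& s) { TF_CHECK_OK(s); done.Notify(); });
+//   done.WaitForNotification();
+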
+// To register a gradient function for a builtin op, one should use
+// REGISTER_OP_GRADIENT(<op_name>, <c++ grad factory>);
+//
+// Typically, the c++ grad factory is a plain function that can be
+// converted into ::tensorflow::gradient::Creator, which is
+// std::function<Status(const AttrSlice&, FunctionDef*)>.
+//
+// A ::tensorflow::gradient::Creator should populate the FunctionDef* with a
+// definition of a function which computes the gradient for the
+// <op_name> when the <op_name> is instantiated with the given attrs.
+//
+// E.g.,
+//
+// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
+// bool transpose_a;
+// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a));
+// bool transpose_b;
+// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b));
+// DataType dtype;
+// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype));
+// if (!transpose_a && !transpose_b) {
+// *g = FunctionDefHelper::Define(
+// "MatMulGrad",
+// {"x:T ", "y:T", "dz:T"}, // Inputs to this function
+// {"dx:T", "dy:T"}, // Outputs from this function
+// {"T: {float, double}"}, // Attributes needed by this function
+// {
+// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}},
+// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}},
+// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}},
+//         {{"dy"}, "MatMul", {"x_t", "dz"}, {{"T", "$T"}}},
+// });
+// } else {
+// ... ...
+// }
+// return Status::OK();
+// }
+//
+// NOTE: $T is substituted with the type variable "T" when the
+// gradient function MatMulGrad is instantiated.
+//
+// TODO(zhifengc): Better documentation somewhere.
+
+// Macros to define a gradient function factory for a primitive
+// operation.
+#define REGISTER_OP_GRADIENT(name, fn) \
+ REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn)
+
+#define REGISTER_OP_NO_GRADIENT(name) \
+ REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr)
+
+#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
+ REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
+
+#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
+ static bool unused_grad_##ctr = ::tensorflow::gradient::RegisterOp(name, fn)
+
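+// For example, with the MatMulGrad factory sketched above:
+//
+//   REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
+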
+namespace gradient {
+// Register a gradient creator for the "op".
+typedef std::function<Status(const AttrSlice& attrs, FunctionDef*)> Creator;
+bool RegisterOp(const string& op, Creator func);
+
+// Returns OK if the gradient creator for the "op" is found (may be
+// nullptr if REGISTER_OP_NO_GRADIENT is used).
+Status GetOpGradientCreator(const string& op, Creator* creator);
+}  // end namespace gradient
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
diff --git a/tensorflow/core/framework/function.proto b/tensorflow/core/framework/function.proto
new file mode 100644
index 0000000000..4b8a26947c
--- /dev/null
+++ b/tensorflow/core/framework/function.proto
@@ -0,0 +1,68 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/op_def.proto";
+
+// A library is a set of named functions.
+message FunctionDefLibrary {
+ repeated FunctionDef function = 1;
+}
+
+// A function can be instantiated when the runtime can bind every attr
+// with a value. When a GraphDef has a call to a function, it must
+// have binding for every attr defined in the signature.
+//
+// TODO(zhifengc):
+// * device spec, etc.
+message FunctionDef {
+ // The definition of the function's name, arguments, return values,
+ // attrs etc.
+ OpDef signature = 1;
+
+ // The body of the function.
+ repeated Node node = 2; // function.node.ret[*] are unique.
+
+ // A node is a multi-value assignment:
+ // (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
+ //
+ // By convention, "func" is resolved by consulting with a user-defined
+ // library first. If not resolved, "func" is assumed to be a builtin op.
+ message Node {
+ // This node produces multiple outputs. They are named ret[0],
+ // ret[1], ..., etc.
+ //
+ // REQUIRES: function.node.ret[*] are unique across all nodes.
+ // REQUIRES: ret.size == func/op def's number of output args.
+ repeated string ret = 1;
+
+ // The op/function name.
+ string op = 2;
+
+ // Arguments passed to this func/op.
+ //
+ // arg[i] must be either one of
+ // function.signature.input_args[*].name or one of
+ // function.node[*].ret[*].
+ //
+ // REQUIRES: arg.size == func/op def's number of input args.
+ repeated string arg = 3;
+
+ // Control dependencies.
+ //
+ // dep[i] must be one of function.node[*].ret[*] or one of
+ // function.signature.input_args[*].name.
+ repeated string dep = 4;
+
+ // Attrs.
+ //
+ // 'attr' maps names defined by 'func's attr defs to attr values.
+    // Attr values may contain placeholders, which are substituted
+    // recursively with concrete values when this node is instantiated.
+    // These placeholders must name an attr listed in the FunctionDef's
+    // signature.
+ map<string, AttrValue> attr = 5;
+ }
+}
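+
+// Example (informal sketch): the function-body assignment
+//   y = Add[T=$T](a, o)
+// corresponds to a Node with ret: ["y"], op: "Add", arg: ["a", "o"], and an
+// attr map entry that binds "T" to a placeholder resolved at instantiation.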
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
new file mode 100644
index 0000000000..c9483fad18
--- /dev/null
+++ b/tensorflow/core/framework/function_test.cc
@@ -0,0 +1,634 @@
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+typedef FunctionDefHelper FDH;
+
+Status GetOpSig(const string& op, const OpDef** sig) {
+ Status s;
+ *sig = OpRegistry::Global()->LookUp(op, &s);
+ return s;
+}
+
+REGISTER_OP("One")
+ .Output("y: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Returns a tensor with a single element (1) of type T.
+
+y: A scalar of type T.
+
+)doc");
+
+static InstantiateAttrValueMap kNoAttrs;
+
+TEST(TFunc, SquarePlusOne) {
+ RequireDefaultOps();
+ auto fdef = FDH::Define(
+ // Name
+ "SquarePlusOne",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attrs
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {// a = Square<T>(x)
+ {{"a"}, "Square", {"x"}, {{"T", "$T"}}},
+ // o = One<T>()
+ // NOTE: We can also have a Cast<Tin, Tout>(x) instead.
+ {{"o"}, "One", {}, {{"T", "$T"}}},
+ // y = Add<T>(a, o)
+ {{"y"}, "Add", {"a", "o"}, {{"T", "$T"}}}});
+
+ const char* e = R"P(
+SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
+ a = Square[T=$T](x)
+ o = One[T=$T]()
+ y = Add[T=$T](a, o)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ // Instantiate one with T=float
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> (n3:float) {
+ n1 = Square[T=float](n0)
+ n2 = One[T=float]()
+ n3 = Add[T=float](n1, n2)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+// NOTE: This is the simplest Map op. It takes a function f: T -> U.
+REGISTER_OP("Map")
+ .Input("x: N * T")
+ .Output("y: N * U")
+ .Attr("T: type")
+ .Attr("U: type")
+ .Attr("N: int >= 1")
+ // .Attr("func: func_name_with_attr")
+ .Doc(R"doc(
+Applies 'func' to every input. I.e.,
+
+y[i] = func<...>(x[i])
+
+x: N tensors, each of type T;
+y: N tensors, each of type U;
+
+)doc");
+
+TEST(TFunc, AddSquared) {
+ auto fdef = FDH::Define(
+ // Name
+ "AddSquared",
+ // Args
+ {"x: N*T"},
+ // Return values
+ {"y: T"},
+ // Attrs
+ {"N:int", "T:{float, double, int32, int64}"},
+ // Nodes
+ {// a = Map<func=Square<$T>,T=$T,U=$T,N=$N>(x)
+ {{"a"},
+ "Map",
+ {"x"},
+ {{"func", FDH::FunctionRef("Square", {{"T", "$T"}})},
+ {"T", "$T"},
+ {"U", "$T"},
+ {"N", "$N"}}},
+ // y = AddN<N=$N,T=$T>(a)
+ {{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}});
+
+ const char* e = R"P(
+AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
+ a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
+ y = AddN[N=$N, T=$T](a)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ // Instantiate one with T=float
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, {{"N", 3}, {"T", DT_FLOAT}}, GetOpSig,
+ &result));
+ const char* e2 = R"P(
+(n0:float, n1:float, n2:float) -> (n4:float) {
+ n3 = Map[N=3, T=float, U=float, func=Square[T=float]](n0, n1, n2)
+ n4 = AddN[N=3, T=float](n3, n3:1, n3:2)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT, DT_FLOAT, DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+TEST(TFunc, ControlDeps) {
+ auto fdef = FDH::Define(
+ // Name
+ "ControlDeps",
+ // Args
+ {"x: float"},
+ // Return values
+ {},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"}, "One", {}, {{"T", DT_FLOAT}}, {"x"}},
+ {{"u"}, "NoOp", {}, {}, {"a"}},
+ {{"b"}, "One", {}, {{"T", DT_FLOAT}}, {"u"}},
+ {{"v"}, "NoOp", {}, {}, {"b"}},
+ {{"c"}, "One", {}, {{"T", DT_FLOAT}}, {"a", "v"}},
+ });
+ const char* e = R"P(
+ControlDeps(x:float) -> () {
+ a = One[T=float]() @ x
+ u = NoOp() @ a
+ b = One[T=float]() @ u
+ v = NoOp() @ b
+ c = One[T=float]() @ a, v
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> () {
+ n1 = One[T=float]() @ n0
+ n2 = NoOp() @ n1
+ n3 = One[T=float]() @ n2
+ n4 = NoOp() @ n3
+ n5 = One[T=float]() @ n1, n4
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+TEST(TFunc, XTimesTwo) {
+ auto expect = R"P(
+XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
+ two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
+ scale = Cast[DstT=$T, SrcT=int64](two)
+ y = Mul[T=$T](x, scale)
+}
+)P";
+ EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
+}
+
+TEST(TFunc, WXPlusB) {
+ auto expect = R"P(
+WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
+ mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
+ y = Add[T=$T](mm, b)
+}
+)P";
+ EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
+}
+
+TEST(TFunc, Body_TypeList) {
+ const Tensor kZero = test::AsScalar<int32>(0);
+ auto fdef = FDH::Define(
+ // Name
+ "Test",
+ // Args
+ {"i:float"},
+ // Return values
+ {"o:float"},
+ // Attrs
+ {},
+ // Nodes
+ {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
+ {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
+ {{"a", "b", "c", "d"},
+ "_ArrayToList",
+ {"s"},
+ {{"N", 4},
+ {"T", DT_FLOAT},
+ {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
+ {{"l"}, "Mul", {"a", "b"}, {{"T", DT_FLOAT}}},
+ {{"r"}, "Mul", {"c", "d"}, {{"T", DT_FLOAT}}},
+ {{"x"}, "_ListToArray", {"l", "r"}, {{"N", 2}, {"T", DT_FLOAT}}},
+ {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
+
+ const char* e = R"P(
+Test(i:float) -> (o:float) {
+ zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
+ s = Split[T=float, num_split=4](zero, i)
+ a, b, c, d = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s)
+ l = Mul[T=float](a, b)
+ r = Mul[T=float](c, d)
+ x = _ListToArray[N=2, T=float](l, r)
+ o = AddN[N=2, T=float](x)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> (n7:float) {
+ n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
+ n2 = Split[T=float, num_split=4](n1, n0)
+ n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3)
+ n4 = Mul[T=float](n3, n3:1)
+ n5 = Mul[T=float](n3:2, n3:3)
+ n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5)
+ n7 = AddN[N=2, T=float](n6, n6:1)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+REGISTER_OP("Cond")
+ .Input("input: Tin")
+ .Output("output: out_types")
+ .Attr("Tin: list(type)")
+ .Attr("out_types: list(type)")
+ .Attr("cond: func")
+ .Attr("then_branch: func")
+ .Attr("else_branch: func")
+ .Doc(R"doc(
+output = Cond(input) ? then_branch(input) : else_branch(input)
+
+cond: A function that takes 'input' and returns a scalar.
+then_branch: A function that takes 'input' and returns 'output'.
+else_branch: A function that takes 'input' and returns 'output'.
+)doc");
+
+TEST(TFunc, Body_Array_List_Converter) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x:float"},
+ // Return values
+ {"z:float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x"},
+ {{"Tin", DataTypeSlice{DT_FLOAT}},
+ {"out_types", DataTypeSlice{DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond")},
+ {"then_branch", FDH::FunctionRef("MyThen")},
+ {"else_branch", FDH::FunctionRef("MyElse")}}},
+ {{"z"},
+ "Cond",
+ {"y", "y"},
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"out_types", DataTypeSlice{DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+
+ const char* e = R"P(
+MySelect(x:float) -> (z:float) {
+ y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
+ z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
+}
+)P";
+ EXPECT_EQ(DebugString(fdef), e);
+
+ InstantiationResult result;
+ TF_CHECK_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
+ const char* e2 = R"P(
+(n0:float) -> (n2:float) {
+ n1 = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](n0)
+ n2 = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](n1, n1)
+}
+)P";
+ EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(result.ret_types, DataTypeVector({DT_FLOAT}));
+ EXPECT_EQ(DebugString(result.gdef), e2);
+}
+
+static void HasError(const Status& s, const string& substr) {
+ EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ << s << ", expected substring " << substr;
+}
+
+TEST(InstantiateErrors, Not_Sufficient_Attrs) {
+ auto fdef =
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, {{"U", DT_FLOAT}}, GetOpSig, &result),
+ "T is not found");
+}
+
+TEST(InstantiateErrors, AttrValue_Value_Placeholder) {
+ auto fdef =
+ FDH::Define("nop", {}, {}, {"T:{float, double, int32, int64}"}, {});
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, {{"T", "$bad"}}, GetOpSig, &result),
+ "T in attr_values is still a placeholder");
+}
+
+TEST(InstantiateErrors, Unbounded_Attr) {
+ auto fdef = FDH::Define("test", {}, {}, {"T:{float, double, int32, int64}"},
+ {
+ {{"a"}, "One", {}, {{"T", "$unknown"}}, {"x"}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, {{"T", DT_FLOAT}}, GetOpSig, &result),
+ "Failed to bind all placeholders");
+}
+
+TEST(InstantiateErrors, DupArgs) {
+ auto fdef = FDH::Define("test", {"x:float", "x:float"}, {}, {}, {});
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Duplicated arg name");
+}
+
+TEST(InstantiateErrors, Dup_Arg_Node_Name) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Duplicated ret name");
+}
+
+TEST(InstantiateErrors, Dup_Node_Names) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
+ {{"y"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Duplicated ret name");
+}
+
+TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Expect one ret name");
+}
+
+TEST(InstantiateErrors, Node_Signature_Mismatch) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Malformed function node (#ret)");
+}
+
+TEST(InstantiateErrors, Node_Arg_Notfound) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "arg[1] is not found");
+}
+
+TEST(InstantiateErrors, Node_Arg_Mismatch) {
+ auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Invalid arg(0) for function arg");
+}
+
+TEST(InstantiateErrors, Node_Arg_ControlMissing) {
+ auto fdef =
+ FDH::Define("test", {"x:float"}, {}, {},
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", DT_FLOAT}}, {"z"}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "dep[0] is not found");
+}
+
+TEST(InstantiateErrors, FuncRet_Missing) {
+ auto fdef = FDH::Define("test", {}, {"y: float"}, {},
+ {
+ {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "ret is not found");
+}
+
+TEST(InstantiateErrors, FuncRet_Mismatch) {
+ auto fdef = FDH::Define("test", {}, {"y: float"}, {},
+ {
+ {{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Invalid ret name.\n\t In y");
+}
+
+TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x", "x"},
+ {{"tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Missing attr out_types");
+}
+
+TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x", "x"},
+ {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "Wrong #ret: 0 2 1");
+}
+
+TEST(InstantiateErrors, TypeList_Missing_Arg) {
+ auto fdef = FDH::Define(
+ // Name
+ "MySelect",
+ // Args
+ {"x: float"},
+ // Return values
+ {"y: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"y"},
+ "Cond",
+ {"x", "unknown"},
+ {{"out_types", DataTypeSlice{DT_FLOAT}},
+ {"cond", FDH::FunctionRef("MyCond2")},
+ {"then_branch", FDH::FunctionRef("MyThen2")},
+ {"else_branch", FDH::FunctionRef("MyElse2")}}},
+ });
+ InstantiationResult result;
+ HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+ "arg[1] is not found");
+}
+
+TEST(FunctionCallFrame, Void_Void) {
+ FunctionCallFrame frame({}, {});
+ EXPECT_OK(frame.SetArgs({}));
+ auto a = test::AsTensor<float>({100});
+ HasError(frame.SetArgs({a}), "Invalid argument");
+ Tensor v;
+ HasError(frame.GetArg(0, &v), "Out of range");
+ HasError(frame.SetRetval(0, v), "Out of range");
+ std::vector<Tensor> rets;
+ EXPECT_OK(frame.GetRetvals(&rets));
+ EXPECT_EQ(rets.size(), 0);
+}
+
+TEST(FunctionCallFrame, Float_Float_Float) {
+ FunctionCallFrame frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+ HasError(frame.SetArgs({}), "Invalid argument: Expects 2 arguments");
+ auto a = test::AsTensor<float>({100});
+ auto b = test::AsTensor<float>({200});
+ auto c = test::AsTensor<int64>({300});
+ HasError(frame.SetArgs({a, c}),
+ "Invalid argument: Expects arg[1] to be float");
+ EXPECT_OK(frame.SetArgs({a, b}));
+
+ Tensor v;
+ HasError(frame.GetArg(-1, &v), "Out of range");
+ HasError(frame.GetArg(2, &v), "Out of range");
+ EXPECT_OK(frame.GetArg(0, &v));
+ test::ExpectTensorEqual<float>(a, v);
+ EXPECT_OK(frame.GetArg(1, &v));
+ test::ExpectTensorEqual<float>(b, v);
+
+ v = test::AsTensor<float>({-100});
+ HasError(frame.SetRetval(-1, v), "Out of range");
+ HasError(frame.SetRetval(1, v), "Out of range");
+ HasError(frame.SetRetval(0, test::AsTensor<int64>({-100})),
+ "Invalid argument: Expects ret[0] to be float");
+
+ std::vector<Tensor> rets;
+ HasError(frame.GetRetvals(&rets), "does not have value");
+ EXPECT_OK(frame.SetRetval(0, v));
+ HasError(frame.SetRetval(0, v), "has already been set");
+
+ EXPECT_OK(frame.GetRetvals(&rets));
+ EXPECT_EQ(rets.size(), 1);
+ test::ExpectTensorEqual<float>(rets[0], v);
+}
+
+TEST(Canonicalize, Basic) {
+ EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
+ {"transpose_a", false},
+ {"transpose_b", false}}),
+ "MatMul[T=float,transpose_a=false,transpose_b=false]");
+ EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_FLOAT},
+ {"transpose_b", false},
+ {"transpose_a", false}}),
+ "MatMul[T=float,transpose_a=false,transpose_b=false]");
+ EXPECT_EQ(Canonicalize("MatMul", {{"T", DT_DOUBLE},
+ {"transpose_b", true},
+ {"transpose_a", false}}),
+ "MatMul[T=double,transpose_a=false,transpose_b=true]");
+}
+
+TEST(FunctionLibraryDefinitionTest, Find) {
+ FunctionDefLibrary proto;
+ *proto.add_function() = test::function::XTimesTwo();
+ FunctionLibraryDefinition lib_def(proto);
+
+ EXPECT_EQ(lib_def.Find("XTimes16"), nullptr);
+
+ auto expect = R"P(
+XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
+ two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
+ scale = Cast[DstT=$T, SrcT=int64](two)
+ y = Mul[T=$T](x, scale)
+}
+)P";
+ auto found = lib_def.Find("XTimesTwo");
+ ASSERT_NE(found, nullptr);
+ EXPECT_EQ(expect, DebugString(*found));
+}
+
+TEST(FunctionLibraryDefinitionTest, LookUp) {
+ FunctionDefLibrary proto;
+ *proto.add_function() = test::function::XTimesTwo();
+ FunctionLibraryDefinition lib_def(proto);
+
+ Status s;
+ EXPECT_EQ(lib_def.LookUp("XTimes16", &s), nullptr);
+
+ auto found = lib_def.LookUp("XTimesTwo", &s);
+ ASSERT_NE(found, nullptr);
+ EXPECT_EQ(found->DebugString(),
+ test::function::XTimesTwo().signature().DebugString());
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
new file mode 100644
index 0000000000..5ead947076
--- /dev/null
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -0,0 +1,146 @@
+#include "tensorflow/core/framework/function_testlib.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+
+namespace tensorflow {
+namespace test {
+namespace function {
+
+typedef FunctionDefHelper FDH;
+
+GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
+ gtl::ArraySlice<FunctionDef> funcs) {
+ GraphDef g;
+ for (auto n : nodes) {
+ *(g.add_node()) = n;
+ }
+ auto lib = g.mutable_library();
+ for (auto f : funcs) {
+ *(lib->add_function()) = f;
+ }
+ return g;
+}
+
+// Helper to construct a NodeDef.
+NodeDef NDef(const string& name, const string& op,
+ gtl::ArraySlice<string> inputs,
+ gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
+ const string& device) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(op);
+ for (auto in : inputs) n.add_input(in);
+ n.set_device(device);
+ for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
+ return n;
+}
+
+FunctionDef NonZero() {
+ return FDH::Define(
+ // Name
+ "NonZero",
+ // Args
+ {"x:T"},
+ // Return values
+ {"y:T"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ // Nodes
+ {
+ {{"y"}, "Identity", {"x"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XTimesTwo() {
+ const Tensor kTwo = test::AsScalar<int64>(2);
+ return FDH::Define(
+ // Name
+ "XTimesTwo",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
+ {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XTimesFour() {
+ return FDH::Define(
+ // Name
+ "XTimesFour",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
+ {{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XTimes16() {
+ return FDH::Define(
+ // Name
+ "XTimes16",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
+ {{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef WXPlusB() {
+ return FDH::Define(
+ // Name
+ "WXPlusB",
+ // Args
+ {"w: T", "x: T", "b: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double}"},
+ // Nodes
+ {{{"mm"},
+ "MatMul",
+ {"w", "x"},
+ {{"T", "$T"},
+ {"transpose_a", false},
+ {"transpose_b", false},
+ {"_kernel", "eigen"}}},
+ {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
+}
+
+FunctionDef Swap() {
+ return FDH::Define(
+ // Name
+ "Swap",
+ // Args
+ {"i0: T", "i1: T"},
+ // Return values
+ {"o0: T", "o1: T"},
+ // Attr def
+ {"T: {float, double}"},
+ // Nodes
+ {{{"o0"}, "Identity", {"i1"}, {{"T", "$T"}}},
+ {{"o1"}, "Identity", {"i0"}, {{"T", "$T"}}}});
+}
+
+} // end namespace function
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
new file mode 100644
index 0000000000..ed0446ea85
--- /dev/null
+++ b/tensorflow/core/framework/function_testlib.h
@@ -0,0 +1,53 @@
+#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
+#define TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace test {
+namespace function {
+
+// Helper to construct a NodeDef.
+NodeDef NDef(
+ const string& name, const string& op, gtl::ArraySlice<string> inputs,
+ gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>>
+ attrs = {},
+ const string& device = "");
+
+// Helper to construct a GraphDef proto.
+GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
+ gtl::ArraySlice<FunctionDef> funcs = {});
+
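+// Example (sketch; assumes the ops used here are registered):
+//
+//   GraphDef g = GDef(
+//       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),
+//        NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}})},
+//       {XTimesTwo()});
+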
+// For testing convenience, we provide a few simple functions that can
+// be easily executed and tested.
+
+// x:T -> x * 2.
+FunctionDef XTimesTwo();
+
+// x:T -> (x * 2) * 2.
+FunctionDef XTimesFour();
+
+// x:T -> ((x * 2) * 2) * 2.
+FunctionDef XTimes16();
+
+// w:T, x:T, b:T -> MatMul(w, x) + b
+FunctionDef WXPlusB();
+
+// x:T -> x:T, where T is a type that can be automatically converted to a bool.
+FunctionDef NonZero();
+
+// x:T, y:T -> y:T, x:T
+FunctionDef Swap();
+
+} // end namespace function
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_FUNCTION_TESTLIB_H_
diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto
new file mode 100644
index 0000000000..a9bc07e88c
--- /dev/null
+++ b/tensorflow/core/framework/graph.proto
@@ -0,0 +1,103 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/function.proto";
+
+// Represents the graph of operations
+// TODO(sanjay): Also want to put the following somewhere:
+// * random_seed
+// * replicas: Do we stamp them out in python itself?
+// * where to load parameters
+// * optimizer info? does it go with the parameter layers/ops?
+message GraphDef {
+ repeated NodeDef node = 1;
+
+ // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
+ //
+ // "library" provides user-defined functions.
+ //
+ // Naming:
+ // * library.function.name are in a flat namespace.
+ // NOTE: We may need to change it to be hierarchical to support
+ // different orgs. E.g.,
+ // { "/google/nn", { ... }},
+ // { "/google/vision", { ... }}
+ // { "/org_foo/module_bar", {...}}
+ // map<string, FunctionDefLib> named_lib;
+  //   * If node[i].op is the name of a function in "library",
+  //     node[i] is treated as a function call. Otherwise, node[i].op
+  //     must be a primitive operation supported by the runtime.
+ //
+ //
+ // Function call semantics:
+ //
+  //   * The callee may start execution as soon as some of its inputs
+  //     are ready. The caller may want to use the Tuple() mechanism to
+  //     ensure that all inputs are ready at the same time.
+  //
+  //   * The consumer of return values may start executing as soon as
+  //     the return values the consumer depends on are ready. The
+  //     consumer may want to use the Tuple() mechanism to ensure that
+  //     it does not start until all return values of the callee
+  //     function are ready.
+ FunctionDefLibrary library = 2;
+};
+
+message NodeDef {
+ // The name given to this operator. Used for naming inputs,
+ // logging, visualization, etc. Unique within a single GraphDef.
+ // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
+ string name = 1;
+
+ // The operation name. There may be custom parameters in attrs.
+ // Op names starting with an underscore are reserved for internal use.
+ string op = 2;
+
+ // Each input is "node:src_output" with "node" being a string name and
+ // "src_output" indicating which output tensor to use from "node". If
+ // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
+ // may optionally be followed by control inputs that have the format
+ // "^node".
+ repeated string input = 3;
+
+ // A (possibly partial) specification for the device on which this
+ // node should be placed.
+ // The expected syntax for this string is as follows:
+ //
+ // DEVICE_SPEC ::= COLOCATED_NODE | PARTIAL_SPEC
+ //
+ // COLOCATED_NODE ::= "@" NODE_NAME // See NodeDef.name above.
+ // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
+ // CONSTRAINT ::= ("job:" JOB_NAME)
+  //                | ("replica:" [0-9]+)
+  //                | ("task:" [0-9]+)
+  //                | ( ("gpu" | "cpu") ":" ([0-9]+ | "*") )
+ //
+ // Valid values for this string include:
+ // * "@other/node" (colocate with "other/node")
+ // * "/job:worker/replica:0/task:1/gpu:3" (full specification)
+ // * "/job:worker/gpu:3" (partial specification)
+ // * "" (no specification)
+ //
+ // If the constraints do not resolve to a single device (or if this
+ // field is empty or not present), the runtime will attempt to
+ // choose a device automatically.
+ string device = 4;
+
+ // Operation-specific graph-construction-time configuration.
+ // Note that this should include all attrs defined in the
+ // corresponding OpDef, including those with a value matching
+ // the default -- this allows the default to change and makes
+ // NodeDefs easier to interpret on their own. However, if
+ // an attr with a default is not specified in this list, the
+ // default will be used.
+ // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
+ // one of the names from the corresponding OpDef's attr field).
+ // The values must have a type matching the corresponding OpDef
+ // attr's type field.
+ // TODO(josh11b): Add some examples here showing best practices.
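+  // For instance (illustrative only), a MatMul node might carry:
+  //   attr { key: "transpose_a" value { b: false } }
+  //   attr { key: "transpose_b" value { b: false } }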
+ map<string, AttrValue> attr = 5;
+};
diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc
new file mode 100644
index 0000000000..1e0d280126
--- /dev/null
+++ b/tensorflow/core/framework/graph_def_util.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/framework/graph_def_util.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+string SummarizeGraphDef(const GraphDef& graph_def) {
+ string ret;
+ for (const NodeDef& node : graph_def.node()) {
+ strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
+ }
+ return ret;
+}
+
+Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
+ for (const NodeDef& node : graph_def.node()) {
+ TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h
new file mode 100644
index 0000000000..7a2ec9c7a7
--- /dev/null
+++ b/tensorflow/core/framework/graph_def_util.h
@@ -0,0 +1,29 @@
+#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Produce a human-readable version of a GraphDef that is more concise
+// than a text-format proto.
+string SummarizeGraphDef(const GraphDef& graph_def);
+
+// Validates the syntax of a GraphDef provided externally.
+//
+// The following is an EBNF-style syntax for GraphDef objects. Note that
+// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+// which contain many other fields that are not (currently) validated.
+//
+// Graph = Node *
+// Node = NodeName, Inputs
+// Inputs = ( DataInput * ), ( ControlInput * )
+// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
+// ControlInput = "^", NodeName
+// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
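+//
+// For example (illustrative), the inputs "a", "b:2", and "^c" are all
+// syntactically valid: a plain data input, a data input taking output 2 of
+// node "b", and a control input on node "c".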
+Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto
new file mode 100644
index 0000000000..db7856a156
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def.proto
@@ -0,0 +1,33 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+
+message KernelDef {
+ // Must match the name of an Op.
+ string op = 1;
+
+ // Type of device this kernel runs on.
+ string device_type = 2;
+
+ message AttrConstraint {
+ // Name of an attr from the Op.
+ string name = 1;
+
+ // A list of values that this kernel supports for this attr.
+ // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops.
+ AttrValue allowed_values = 2;
+ }
+ repeated AttrConstraint constraint = 3;
+
+ // Names of the Op's input_/output_args that reside in host memory
+ // instead of device memory.
+ repeated string host_memory_arg = 4;
+
+  // This allows experimental kernels to be registered for an op; such a
+  // kernel won't be used unless the user specifies a "_kernel" attr
+  // whose value matches this label.
+ string label = 5;
+}
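+
+// Example (illustrative): a kernel registered with label "eigen" is only
+// selected for nodes whose NodeDef sets the attr _kernel="eigen", as the
+// WXPlusB test function does for its MatMul node.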
diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc
new file mode 100644
index 0000000000..8fba883a16
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder.cc
@@ -0,0 +1,47 @@
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+KernelDefBuilder::KernelDefBuilder(const char* op_name) {
+ kernel_def_ = new KernelDef;
+ kernel_def_->set_op(op_name);
+}
+
+KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
+ kernel_def_->set_device_type(device_type);
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::TypeConstraint(
+ const char* attr_name, gtl::ArraySlice<DataType> allowed) {
+ auto* constraint = kernel_def_->add_constraint();
+ constraint->set_name(attr_name);
+ auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
+ for (DataType dt : allowed) {
+ allowed_values->add_type(dt);
+ }
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
+ DataType allowed) {
+ auto* constraint = kernel_def_->add_constraint();
+ constraint->set_name(attr_name);
+ constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
+ kernel_def_->add_host_memory_arg(arg_name);
+ return *this;
+}
+
+KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
+ CHECK_EQ(kernel_def_->label(), "")
+ << "Trying to set a kernel's label a second time: '" << label
+ << "' in: " << kernel_def_->ShortDebugString();
+ kernel_def_->set_label(label);
+ return *this;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
new file mode 100644
index 0000000000..0c14d1e006
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
+class KernelDefBuilder {
+ public:
+ // Starts with just the name field set.
+ // Caller MUST call Build() and take ownership of the result.
+ explicit KernelDefBuilder(const char* op_name);
+
+ ~KernelDefBuilder() {
+ DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
+ }
+
+ // Required: specify the type of device this kernel supports.
+ // Returns *this.
+ KernelDefBuilder& Device(const char* device_type);
+ // KernelDefBuilder& Device(DeviceType device_type);
+
+ // Specify that this kernel supports a limited set of values for a
+ // particular type or list(type) attr (a further restriction than
+ // what the Op allows).
+ // Returns *this.
+ KernelDefBuilder& TypeConstraint(const char* attr_name,
+ gtl::ArraySlice<DataType> allowed);
+
+ // Like TypeConstraint but supports just a single type.
+ KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
+
+ // Like TypeConstraint, but (a) gets the type from a template parameter
+ // and (b) only supports a constraint to a single type.
+ template <class T>
+ KernelDefBuilder& TypeConstraint(const char* attr_name);
+ // TODO(josh11b): Support other types of attr constraints as needed.
+
+ // Specify that this kernel requires/provides an input/output arg
+ // in host memory (instead of the default, device memory).
+ // Returns *this.
+ KernelDefBuilder& HostMemory(const char* arg_name);
+
+ // Specify that this kernel requires a particular value for the
+ // "_kernel" attr. May only be specified once. Returns *this.
+ KernelDefBuilder& Label(const char* label);
+
+ // Returns a pointer to a KernelDef with fields set based on the
+ // above calls to this instance.
+ // Caller takes ownership of the result.
+ const KernelDef* Build() {
+ KernelDef* r = kernel_def_;
+ kernel_def_ = nullptr;
+ return r;
+ }
+
+ private:
+ KernelDef* kernel_def_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
+};
+
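+// Example usage (a sketch mirroring kernel_def_builder_test.cc; the caller
+// owns and must delete the returned KernelDef):
+//
+//   const KernelDef* def = KernelDefBuilder("MatMul")
+//                              .Device(DEVICE_CPU)
+//                              .TypeConstraint<float>("T")
+//                              .Build();
+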
+// IMPLEMENTATION
+
+template <class T>
+inline KernelDefBuilder& KernelDefBuilder::TypeConstraint(
+ const char* attr_name) {
+ return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/kernel_def_builder_test.cc b/tensorflow/core/framework/kernel_def_builder_test.cc
new file mode 100644
index 0000000000..eba7144b59
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder_test.cc
@@ -0,0 +1,76 @@
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(KernelDefBuilderTest, Basic) {
+ const KernelDef* def = KernelDefBuilder("A").Device(DEVICE_CPU).Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString("op: 'A' device_type: 'CPU'",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+}
+
+TEST(KernelDefBuilderTest, TypeConstraint) {
+ const KernelDef* def = KernelDefBuilder("B")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'B' device_type: 'GPU'
+ constraint { name: 'T' allowed_values { list { type: DT_FLOAT } } } )proto",
+ &expected);
+
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+
+ def = KernelDefBuilder("C")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("U")
+ .TypeConstraint<bool>("V")
+ .Build();
+
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'C' device_type: 'GPU'
+ constraint { name: 'U' allowed_values { list { type: DT_INT32 } } }
+ constraint { name: 'V' allowed_values { list { type: DT_BOOL } } } )proto",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+
+ def = KernelDefBuilder("D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint("W", {DT_DOUBLE, DT_STRING})
+ .Build();
+ protobuf::TextFormat::ParseFromString(R"proto(
+ op: 'D' device_type: 'CPU'
+ constraint { name: 'W'
+ allowed_values { list { type: [DT_DOUBLE, DT_STRING] } } } )proto",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+}
+
+TEST(KernelDefBuilderTest, HostMemory) {
+ const KernelDef* def = KernelDefBuilder("E")
+ .Device(DEVICE_GPU)
+ .HostMemory("in")
+ .HostMemory("out")
+ .Build();
+ KernelDef expected;
+ protobuf::TextFormat::ParseFromString(
+ "op: 'E' device_type: 'GPU' "
+ "host_memory_arg: ['in', 'out']",
+ &expected);
+ EXPECT_EQ(def->DebugString(), expected.DebugString());
+ delete def;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc
new file mode 100644
index 0000000000..c660b84aa0
--- /dev/null
+++ b/tensorflow/core/framework/lookup_interface.cc
@@ -0,0 +1,45 @@
+#include "tensorflow/core/framework/lookup_interface.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace lookup {
+
+Status LookupInterface::CheckKeyAndValueTensors(const Tensor& key,
+ const Tensor& value) {
+ if (key.dtype() != key_dtype()) {
+ return errors::InvalidArgument("Key must be type ", key_dtype(),
+ " but got ", key.dtype());
+ }
+ if (value.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Value must be type ", value_dtype(),
+ " but got ", value.dtype());
+ }
+ if (key.NumElements() != value.NumElements()) {
+ return errors::InvalidArgument("Number of elements of key(",
+ key.NumElements(), ") and value(",
+ value.NumElements(), ") are different.");
+ }
+ if (!key.shape().IsSameSize(value.shape())) {
+ return errors::InvalidArgument("key and value have different shapes.");
+ }
+ return Status::OK();
+}
+
+Status LookupInterface::CheckFindArguments(const Tensor& key,
+ const Tensor& value,
+ const Tensor& default_value) {
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(key, value));
+
+ if (default_value.dtype() != value_dtype()) {
+ return errors::InvalidArgument("Default value must be type ", value_dtype(),
+ " but got ", default_value.dtype());
+ }
+ if (!TensorShapeUtils::IsScalar(default_value.shape())) {
+ return errors::InvalidArgument("Default values must be scalar.");
+ }
+ return Status::OK();
+}
+
+} // namespace lookup
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
new file mode 100644
index 0000000000..d4036d2019
--- /dev/null
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -0,0 +1,65 @@
+#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Lookup interface for batch lookups used by table lookup ops.
+class LookupInterface : public ResourceBase {
+ public:
+  // Performs batch lookups: for every element in the key tensor, Find returns
+  // the corresponding value in the values tensor.
+  // If an element is not present in the table, the given default value is used.
+
+ // For tables that require initialization, Find is available once the table
+ // is marked as initialized.
+
+ // Returns the following statuses:
+ // - OK: when the find finishes successfully.
+ // - FailedPrecondition: if the table is not initialized.
+ // - InvalidArgument: if any of the preconditions on the lookup key or value
+ // fails.
+ // - In addition, other implementations may provide another non-OK status
+ // specific to their failure modes.
+ virtual Status Find(const Tensor& keys, Tensor* values,
+ const Tensor& default_value) = 0;
+
+ // Returns the number of elements in the table.
+ virtual size_t size() const = 0;
+
+ // Returns the data type of the key.
+ virtual DataType key_dtype() const = 0;
+
+ // Returns the data type of the value.
+ virtual DataType value_dtype() const = 0;
+
+ string DebugString() override { return "A lookup table"; }
+
+ protected:
+ virtual ~LookupInterface() = default;
+
+ // Check format of the key and value tensors.
+ // Returns OK if all the following requirements are satisfied, otherwise it
+ // returns InvalidArgument:
+  // - DataType of the key tensor equals the table's key_dtype
+  // - DataType of the value tensor equals the table's value_dtype
+ // - key and value have the same size and shape
+ Status CheckKeyAndValueTensors(const Tensor& keys, const Tensor& values);
+
+ // Check the arguments of a find operation. Returns OK if all the following
+ // requirements are satisfied, otherwise it returns InvalidArgument:
+ // - All requirements of CheckKeyAndValueTensors
+  // - default_value's type equals the table's value_dtype
+ // - default_value is scalar
+ Status CheckFindArguments(const Tensor& keys, const Tensor& values,
+ const Tensor& default_value);
+};
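+
+// Example (sketch; "table" is assumed to point at a concrete LookupInterface
+// implementation, with "keys" and "default_value" properly typed tensors):
+//
+//   Tensor values(table->value_dtype(), keys.shape());
+//   TF_RETURN_IF_ERROR(table->Find(keys, &values, default_value));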
+
+} // namespace lookup
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
new file mode 100644
index 0000000000..12757f153a
--- /dev/null
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -0,0 +1,194 @@
+#include "tensorflow/core/framework/node_def_builder.h"
+
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry) {
+ node_def_.set_name(name);
+ Status status;
+ op_def_ = op_registry->LookUp(op_name, &status);
+ if (op_def_ == nullptr) {
+ errors_.push_back(status.error_message());
+ inputs_specified_ = 0;
+ } else {
+ Initialize();
+ }
+}
+
+NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def)
+ : op_def_(op_def) {
+ node_def_.set_name(name);
+ Initialize();
+}
+
+void NodeDefBuilder::Initialize() {
+ inputs_specified_ = 0;
+ node_def_.set_op(op_def_->name());
+}
+
+const OpDef::ArgDef* NodeDefBuilder::NextArgDef() {
+ if (!NextArgAvailable()) return nullptr;
+ return &op_def_->input_arg(inputs_specified_++);
+}
+
+bool NodeDefBuilder::NextArgAvailable() {
+ if (op_def_ == nullptr) {
+ return false;
+ } else if (inputs_specified_ >= op_def_->input_arg_size()) {
+ errors_.push_back(strings::StrCat("More Input() calls than the ",
+ op_def_->input_arg_size(),
+ " input_args"));
+ return false;
+ }
+ return true;
+}
+
+NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
+ if (NextArgAvailable()) {
+ Status status =
+ fake_input(*op_def_, inputs_specified_, node_def_, this);
+ if (!status.ok()) errors_.push_back(status.error_message());
+ }
+ return *this;
+}
+
+void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
+ const string& src_node, int src_index,
+ DataType dt) {
+ AddInput(src_node, src_index);
+
+ if (!input_arg->number_attr().empty() ||
+ !input_arg->type_list_attr().empty()) {
+ errors_.push_back(strings::StrCat("Single tensor passed to '",
+ input_arg->name(), "', expected list"));
+ return;
+ }
+
+ if (input_arg->type() != DT_INVALID) {
+ const DataType expected = MaybeAddRef(input_arg, input_arg->type());
+ VerifyInputType(input_arg, expected, dt);
+ } else {
+ VerifyInputRef(input_arg, dt);
+ Attr(input_arg->type_attr(), BaseType(dt));
+ }
+}
+
+void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
+ gtl::ArraySlice<NodeOut> src_list) {
+ for (const auto& node_out : src_list) {
+ AddInput(node_out.node, node_out.index);
+ }
+
+ if (!input_arg->number_attr().empty()) {
+ Attr(input_arg->number_attr(), static_cast<int64>(src_list.size()));
+ if (input_arg->type() != DT_INVALID) {
+ const DataType expected = MaybeAddRef(input_arg, input_arg->type());
+ for (const auto& node_out : src_list) {
+ VerifyInputType(input_arg, expected, node_out.data_type);
+ }
+ } else if (!src_list.empty()) {
+ const DataType base = BaseType(src_list[0].data_type);
+ Attr(input_arg->type_attr(), base);
+ const DataType expected = MaybeAddRef(input_arg, base);
+ for (const auto& node_out : src_list) {
+ VerifyInputType(input_arg, expected, node_out.data_type);
+ }
+ }
+ } else if (!input_arg->type_list_attr().empty()) {
+ DataTypeVector type_vec;
+ type_vec.reserve(src_list.size());
+ for (const auto& node_out : src_list) {
+ const DataType dt = node_out.data_type;
+ VerifyInputRef(input_arg, dt);
+ type_vec.push_back(BaseType(dt));
+ }
+ Attr(input_arg->type_list_attr(), type_vec);
+ } else {
+ errors_.push_back(strings::StrCat("List provided to input '",
+ input_arg->name(),
+ "' when single Tensor expected"));
+ }
+}
+
+void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
+ if (src_node.empty()) {
+ errors_.push_back("Empty input node name");
+ } else if (src_node[0] == '^') {
+ errors_.push_back(
+ strings::StrCat("Non-control input starting with ^: ", src_node));
+ } else if (src_index > 0) {
+ node_def_.add_input(strings::StrCat(src_node, ":", src_index));
+ } else {
+ node_def_.add_input(src_node);
+ }
+}
+
+void NodeDefBuilder::VerifyInputType(const OpDef::ArgDef* input_arg,
+ DataType expected, DataType dt) {
+ if (!TypesCompatible(expected, dt)) {
+ errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
+ DataTypeString(dt), " expected ",
+ DataTypeString(expected)));
+ }
+}
+
+void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
+ DataType dt) {
+ if (input_arg->is_ref() && !IsRefType(dt)) {
+ errors_.push_back(strings::StrCat("Input '", input_arg->name(), "' passed ",
+ DataTypeString(dt),
+ " expected ref type"));
+ }
+}
+
+Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
+ const std::vector<string>* errors_ptr = &errors_;
+ std::vector<string> errors_storage;
+ if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
+ // Since this is a const method, to add an error, we have to make
+ // a copy of the existing errors.
+ errors_storage = errors_;
+ errors_storage.push_back(
+ strings::StrCat(inputs_specified_, " inputs specified of ",
+ op_def_->input_arg_size(), " inputs in Op"));
+ errors_ptr = &errors_storage;
+ }
+
+ if (!errors_ptr->empty()) {
+ if (errors_ptr->size() == 1) {
+ if (op_def_ == nullptr) {
+ return errors::InvalidArgument((*errors_ptr)[0],
+ " while building NodeDef '",
+ node_def_.name(), "'");
+ }
+ return errors::InvalidArgument(
+ (*errors_ptr)[0], " while building NodeDef '", node_def_.name(),
+ "' using ", SummarizeOpDef(*op_def_));
+ } else {
+ return errors::InvalidArgument(
+ errors_ptr->size(), " errors while building NodeDef '",
+ node_def_.name(), "' using ", SummarizeOpDef(*op_def_), ":\n",
+ str_util::Join(*errors_ptr, "\n"));
+ }
+ } else {
+ NodeDef node_def_backup;
+ if (node_def == nullptr) node_def = &node_def_backup;
+ *node_def = node_def_;
+
+ // Add control inputs after the regular inputs.
+ for (const auto& control_input : control_inputs_) {
+ node_def->add_input(strings::StrCat("^", control_input));
+ }
+
+ // Add default values for unspecified attrs.
+ AddDefaultsToNodeDef(*op_def_, node_def);
+
+ return Status::OK();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
new file mode 100644
index 0000000000..706f072608
--- /dev/null
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -0,0 +1,176 @@
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+
+#include <functional>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+class NodeDefBuilder;
+typedef std::function<Status(const OpDef&, int, const NodeDef&,
+ NodeDefBuilder*)> FakeInputFunctor;
+
+// This is a helper for creating a NodeDef. Automatically sets attrs
+// that can be inferred from the inputs, and uses default values
+// (where they exist) for unspecified attrs. Example usage:
+//
+// NodeDef node_def;
+// Status status = NodeDefBuilder(node_name, op_name)
+// .Input(...)
+// .Attr(...)
+// .Finalize(&node_def);
+// if (!status.ok()) return status;
+// // Use node_def here.
+class NodeDefBuilder {
+ public:
+ // To specify an output to be consumed by one of the Input() methods below.
+ struct NodeOut {
+ NodeOut(const string& n, int i, DataType dt)
+ : node(n), index(i), data_type(dt) {}
+ NodeOut() {} // uninitialized, call Reset() before use.
+ void Reset(const string& n, int i, DataType dt) {
+ node = n;
+ index = i;
+ data_type = dt;
+ }
+ string node;
+ int index;
+ DataType data_type;
+ };
+
+ // Specify the name and the Op (either via an OpDef or the name of
+ // the Op plus a registry) for the NodeDef. Other fields are
+ // specified by calling the methods below.
+ // REQUIRES: The OpDef must satisfy ValidateOpDef().
+ NodeDefBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry = OpRegistry::Global());
+ // REQUIRES: in addition, *op_def must outlive *this.
+ NodeDefBuilder(const string& name, const OpDef* op_def);
+
+ // You must call one Input() function per input_arg in the Op,
+ // *and in the same order as the input_args appear in the OpDef.*
+
+ // For inputs that take a single tensor.
+ NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) {
+ const OpDef::ArgDef* arg = NextArgDef();
+ if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
+ return *this;
+ }
+ NodeDefBuilder& Input(const NodeOut& src) {
+ Input(src.node, src.index, src.data_type);
+ return *this;
+ }
+
+ // For inputs that take a list of tensors.
+ NodeDefBuilder& Input(gtl::ArraySlice<NodeOut> src_list) {
+ const OpDef::ArgDef* arg = NextArgDef();
+ if (arg != nullptr) ListInput(arg, src_list);
+ return *this;
+ }
+
+ // To create inputs in tests, see fake_input.h.
+ NodeDefBuilder& Input(FakeInputFunctor fake_input);
+
+ // Specify that this node must only run after src_node.
+ NodeDefBuilder& ControlInput(const string& src_node) {
+ control_inputs_.push_back(src_node);
+ return *this;
+ }
+
+ // Constrains what devices this node may be scheduled on.
+ NodeDefBuilder& Device(const string& device_spec) {
+ node_def_.set_device(device_spec);
+ return *this;
+ }
+
+ // Sets the attr, if not already set. If already set with a different
+ // value, an error will be returned from Finalize().
+ template <class T>
+ NodeDefBuilder& Attr(const string& attr_name, T&& value);
+ // Note: overload needed to allow {...} expressions for value.
+ template <class T>
+ NodeDefBuilder& Attr(const string& attr_name,
+ std::initializer_list<T> value) {
+ Attr<std::initializer_list<T>>(attr_name, std::move(value));
+ return *this;
+ }
+
+ // Finish building the NodeDef, returning any errors or setting
+ // *node_def if none.
+ // WARNING: Not all problems are detected! The resulting NodeDef may
+  // not be valid! Call ValidateNodeDef() from node_def_util.h to be sure.
+ Status Finalize(NodeDef* node_def) const;
+
+ // Accessor for the OpDef set in the constructor.
+ const OpDef& op_def() const { return *op_def_; }
+
+ private:
+ // Called in the constructors.
+ void Initialize();
+
+ // Get the current ArgDef and advance to the next one. Returns nullptr
+ // if no more inputs are available.
+ const OpDef::ArgDef* NextArgDef();
+
+ // Returns true if there is still an input_arg available in *op_def_,
+  // otherwise adds to errors_ and returns false.
+ bool NextArgAvailable();
+
+ // These do the main work of the Input() methods.
+ void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node,
+ int src_index, DataType dt);
+ void ListInput(const OpDef::ArgDef* input_arg,
+ gtl::ArraySlice<NodeOut> src_list);
+
+ // Add "src_node:src_index" to the list of inputs in the node_def_.
+ void AddInput(const string& src_node, int src_index);
+
+  // Generates an error if dt is not compatible with the expected type.
+ void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
+ DataType dt);
+
+ // If input_arg->is_ref() is true, generate an error if dt is not a ref.
+ void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt);
+
+ // Makes dt a ref type if that is what the input_arg specifies.
+ DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) {
+ return input_arg->is_ref() ? MakeRefType(dt) : dt;
+ }
+
+ const OpDef* op_def_;
+ NodeDef node_def_;
+ int inputs_specified_;
+ std::vector<string> control_inputs_;
+ std::vector<string> errors_;
+};
+
+// IMPLEMENTATION -------------------------------------------------------------
+
+template <class T>
+NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) {
+ const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
+ if (found == nullptr) {
+ AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
+ } else {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ if (!AreAttrValuesEqual(*found, attr_value)) {
+ errors_.push_back(strings::StrCat(
+ "Inconsistent values for attr '", attr_name, "' ",
+ SummarizeAttrValue(*found), " vs. ", SummarizeAttrValue(attr_value)));
+ }
+ }
+ return *this;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc
new file mode 100644
index 0000000000..6fd4a8d1ed
--- /dev/null
+++ b/tensorflow/core/framework/node_def_builder_test.cc
@@ -0,0 +1,1036 @@
+#include "tensorflow/core/framework/node_def_builder.h"
+
+#include <memory>
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class NodeDefBuilderTest : public ::testing::Test {
+ protected:
+ // Specify an OpDef via an OpDefBuilder.
+ void Op(const OpDefBuilder& op_def_builder) {
+ EXPECT_OK(op_def_builder.Finalize(&op_def_));
+ }
+
+ // Resets builder_ with a new NodeDefBuilder using the Op from the last call
+ // to Op() above.
+ NodeDefBuilder& Builder() {
+ EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()";
+ builder_.reset(new NodeDefBuilder("n", &op_def_));
+ return *builder_;
+ }
+
+ // Calls Finalize() and verifies it returns success and the result matches
+ // expectations.
+ void ExpectSuccess(const NodeDefBuilder& builder,
+ DataTypeSlice expected_in_types,
+ DataTypeSlice expected_out_types, StringPiece proto) {
+ NodeDef node_def;
+ Status status = builder.Finalize(&node_def);
+ EXPECT_OK(status);
+ if (!status.ok()) return;
+ NodeDef expected;
+ protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto),
+ &expected);
+ EXPECT_EQ(node_def.DebugString(), expected.DebugString());
+
+ DataTypeVector in_types, out_types;
+ status =
+ InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types);
+ EXPECT_OK(status);
+ if (!status.ok()) return;
+ EXPECT_EQ(DataTypeSliceString(expected_in_types),
+ DataTypeVectorString(in_types));
+ EXPECT_EQ(DataTypeSliceString(expected_out_types),
+ DataTypeVectorString(out_types));
+
+ status = ValidateNodeDef(node_def, op_def_);
+ EXPECT_OK(status);
+ }
+
+ // Calls Finalize() and verifies it returns an error.
+ // Each message must appear as a substring of the error.
+ void ExpectFailures(const NodeDefBuilder& builder,
+ const std::vector<string>& messages) {
+ NodeDef node_def;
+ Status status = builder.Finalize(&node_def);
+ EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
+ if (status.ok()) return;
+ for (const string& message : messages) {
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+ << status << ", " << message;
+ }
+ }
+
+ // Calls Finalize() and verifies it returns an error.
+ // Message must appear as a substring of the error.
+ void ExpectFailure(const NodeDefBuilder& builder, const string& message) {
+ ExpectFailures(builder, {message});
+ }
+
+ // Like ExpectFailure(), except that the error can come from
+ // ValidateNodeDef().
+ void ExpectInvalid(const NodeDefBuilder& builder, const string& message) {
+ NodeDef node_def;
+ Status status = builder.Finalize(&node_def);
+ if (status.ok()) {
+ status = ValidateNodeDef(node_def, op_def_);
+ }
+ EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
+ if (status.ok()) return;
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
+ << "Actual error: " << status.error_message()
+ << "\nDoes not contain: " << message;
+ }
+
+ OpDef op_def_;
+ std::unique_ptr<NodeDefBuilder> builder_;
+};
+
+TEST_F(NodeDefBuilderTest, Simple) {
+ Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float"));
+
+ ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT},
+ R"proto( op: "Simple" input: "x" )proto");
+
+ // Port != 0
+ ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT},
+ R"proto( op: "Simple" input: "y:2" )proto");
+
+ // FakeInput
+ ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT},
+ R"proto( op: "Simple" input: "a" )proto");
+
+ // Ref input
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32},
+ {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto");
+
+ // ControlInput
+ ExpectSuccess(
+ Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"),
+ {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: ["a", "^x", "^y"] )proto");
+
+ // Device
+ ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32},
+ {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" device: "ddd" )proto");
+
+ // Extra input
+ ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32),
+ "More Input() calls than the 1 input_args while building "
+ "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> "
+ "out:float>");
+
+ // Missing input
+ ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while");
+
+ { // Finalize() twice.
+ NodeDefBuilder& builder = Builder();
+ builder.Input(FakeInput()).Finalize(nullptr); // First call to Finalize()
+ // ExpectSuccess() also calls Finalize().
+ ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" )proto");
+ }
+
+ { // Input() after Finalize()
+ NodeDefBuilder& builder = Builder();
+ // Calling Finalize() before enough inputs -> error.
+ ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while");
+ builder.Input(FakeInput());
+ // Calling Finalize() with enough inputs -> success
+ ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto(
+ op: "Simple" input: "a" )proto");
+ // Calling Finalize() with too many inputs -> error.
+ builder.Input(FakeInput(DT_INT32));
+ ExpectFailure(builder, "More Input() calls than the 1 input_args while");
+ }
+
+ // Wrong input type
+ ExpectFailure(Builder().Input("x", 0, DT_FLOAT),
+ "Input 'a' passed float expected int32 ");
+
+ ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF),
+ "Input 'a' passed float_ref expected int32 ");
+
+ // List input
+ ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)),
+ "List provided to input 'a' when single Tensor expected while");
+
+ ExpectFailure(Builder().Input(FakeInput(3)),
+ "List provided to input 'a' when single Tensor expected while");
+
+ // Bad ControlInput
+ ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"),
+ "Control input '^z:2' must not have ':' in NodeDef:");
+
+ // Bad input name
+ ExpectFailure(Builder().Input("", 0, DT_INT32),
+ "Empty input node name while");
+
+ ExpectFailure(Builder().Input("^x", 0, DT_INT32),
+ "Non-control input starting with ^: ^x while");
+}
+
+TEST_F(NodeDefBuilderTest, OpDoesNotExist) {
+ NodeDefBuilder builder("n", "Op Does Not Exist");
+ builder.Input(FakeInput())
+ .Input(FakeInput(12))
+ .ControlInput("y")
+ .Attr("foo", 12)
+ .Device("device");
+ ExpectFailure(
+ builder,
+ "Op type not registered 'Op Does Not Exist' while building NodeDef 'n'");
+}
+
+TEST_F(NodeDefBuilderTest, Polymorphic) {
+ Op(OpDefBuilder("Polymorphic")
+ .Input("v: T")
+ .Output("out: T")
+ .Attr("T: type"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32},
+ R"proto(
+ op: "Polymorphic" input: "a"
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT},
+ R"proto(
+ op: "Polymorphic" input: "a"
+ attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+ // Redundant Attr()
+ ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL),
+ {DT_BOOL}, {DT_BOOL}, R"proto(
+ op: "Polymorphic" input: "a"
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+  // Conflicting Attr()
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING),
+ "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
+
+ ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)),
+ "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while");
+
+ ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)),
+ "Inconsistent values for attr 'T' 12 vs. DT_BOOL while");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicOut) {
+ Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type"));
+
+ ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto(
+ op: "PolymorphicOut"
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto(
+ op: "PolymorphicOut"
+ attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+ // Redundant attr
+ ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {},
+ {DT_FLOAT}, R"proto(
+ op: "PolymorphicOut"
+ attr { key: "T" value { type: DT_FLOAT } } )proto");
+
+ // Conflicting attr
+ ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT),
+ "Inconsistent values for attr 'T' DT_BOOL vs. DT_FLOAT while");
+
+ // Missing attr
+ ExpectInvalid(Builder(), "NodeDef missing attr 'T' from");
+
+ // Attr has the wrong type
+ ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}),
+ "AttrValue had value with type list(type) when type expected");
+
+ ExpectInvalid(Builder().Attr("T", 12),
+ "AttrValue had value with type int when type expected");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) {
+ Op(OpDefBuilder("PolymorphicDefaultOut")
+ .Output("out: T")
+ .Attr("T: type = DT_STRING"));
+
+ ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto(
+ op: "PolymorphicDefaultOut"
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto(
+ op: "PolymorphicDefaultOut"
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, Binary) {
+ Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr(
+ "T: type"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)),
+ {DT_INT32, DT_INT32}, {DT_INT32}, R"proto(
+ op: "Binary" input: "a" input: "b"
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()),
+ {DT_STRING, DT_STRING}, {DT_STRING}, R"proto(
+ op: "Binary" input: "a" input: "b"
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ // Type mismatch
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)),
+ "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
+}
+
+TEST_F(NodeDefBuilderTest, Restrict) {
+ Op(OpDefBuilder("Restrict")
+ .Input("a: T")
+ .Output("out: T")
+ .Attr("T: {string, bool}"));
+ ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING},
+ R"proto(
+ op: "Restrict" input: "a"
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput(DT_INT32)),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, TypeList) {
+ Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)"));
+
+ ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
+ {DT_STRING, DT_INT32}, {}, R"proto(
+ op: "TypeList" input: ["a", "a:1"]
+ attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } }
+ )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)),
+ {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto(
+ op: "TypeList" input: ["a", "a:1", "a:2"]
+ attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } }
+ )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput(0)),
+ "Length for attr 'T' of 0 must be at least minimum 1");
+
+ ExpectInvalid(Builder().Input(FakeInput({})),
+ "Length for attr 'T' of 0 must be at least minimum 1");
+
+ ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)),
+ "Single tensor passed to 'a', expected list while");
+
+ ExpectFailures(Builder().Input(FakeInput()),
+ {"2 errors while building NodeDef",
+ "Could not infer list of types for input 'a': "
+ "No attr named 'T' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+}
+
+TEST_F(NodeDefBuilderTest, TypeListNoMin) {
+ Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto(
+ op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto(
+ op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto(
+ op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto(
+ op: "TypeListNoMin" input: "a"
+ attr { key: "T" value { list { type: DT_BOOL } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, TypeListTwice) {
+ Op(OpDefBuilder("TypeListTwice")
+ .Input("a: T")
+ .Input("b: T")
+ .Attr("T: list(type) >= 0"));
+
+ ExpectSuccess(Builder()
+ .Input(FakeInput({DT_INT32, DT_BOOL}))
+ .Input(FakeInput({DT_INT32, DT_BOOL})),
+ {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
+ op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
+ attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");
+
+ ExpectSuccess(
+ Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()),
+ {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
+ op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
+ attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {},
+ R"proto(
+ op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
+ R"proto(
+ op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");
+
+ ExpectFailure(Builder()
+ .Input(FakeInput({DT_INT32, DT_BOOL}))
+ .Input(FakeInput({DT_INT32, DT_STRING})),
+ "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. "
+ "[DT_INT32, DT_STRING] while");
+}
+
+TEST_F(NodeDefBuilderTest, OutTypeList) {
+ Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0"));
+
+ ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto(
+ op: "OutTypeList"
+ attr { key: "T" value { list { type: DT_FLOAT } } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {},
+ {DT_STRING, DT_BOOL}, R"proto(
+ op: "OutTypeList"
+ attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto(
+ op: "OutTypeList"
+ attr { key: "T" value { list { } } } )proto");
+
+ ExpectInvalid(Builder().Attr("T", DT_FLOAT),
+ "AttrValue had value with type type when list(type) expected");
+}
+
+TEST_F(NodeDefBuilderTest, TypeListRestrict) {
+ Op(OpDefBuilder("TypeListRestrict")
+ .Input("a: T")
+ .Attr("T: list({string, bool}) >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})),
+ {DT_STRING, DT_BOOL}, {}, R"proto(
+ op: "TypeListRestrict" input: ["a", "a:1"]
+ attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, OutTypeListRestrict) {
+ Op(OpDefBuilder("OutTypeListRestrict")
+ .Output("out: t")
+ .Attr("t: list({string, bool}) >= 0"));
+
+ ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {},
+ {DT_BOOL, DT_STRING}, R"proto(
+ op: "OutTypeListRestrict"
+ attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto");
+
+ ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}),
+ "Value for attr 't' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, Attr) {
+ Op(OpDefBuilder("Attr").Attr("a: int"));
+
+ ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
+ op: "Attr" attr { key: "a" value { i: 12 } } )proto");
+
+ // Attr has wrong type
+ ExpectInvalid(Builder().Attr("a", "bad"),
+ "AttrValue had value with type string when int expected");
+
+ ExpectInvalid(Builder().Attr("a", {12}),
+ "AttrValue had value with type list(int) when int expected");
+
+ // Missing attr
+ ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<");
+
+ // Wrong attr
+ ExpectInvalid(Builder().Attr("b", 12),
+ "NodeDef mentions attr 'b' not in Op<");
+
+ // Extra attr
+ ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12),
+ "NodeDef mentions attr 'extra' not in Op<");
+}
+
+TEST_F(NodeDefBuilderTest, AttrFloat) {
+ Op(OpDefBuilder("AttrFloat").Attr("a: float"));
+
+ ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto(
+ op: "AttrFloat" attr { key: "a" value { f: 1.2 } }
+ )proto");
+
+ ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto(
+ op: "AttrFloat" attr { key: "a" value { f: 1.2 } }
+ )proto");
+
+ // Won't automatically cast int to float
+ ExpectInvalid(Builder().Attr("a", 12),
+ "AttrValue had value with type int when float expected");
+}
+
+TEST_F(NodeDefBuilderTest, AttrBoolList) {
+ Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)"));
+
+ ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto(
+ op: "AttrBoolList"
+ attr { key: "a" value { list { b: [true, false, true] } } }
+ )proto");
+
+ ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto(
+ op: "AttrBoolList" attr { key: "a" value { list { } } }
+ )proto");
+
+ // Won't cast int -> bool.
+ ExpectInvalid(Builder().Attr("a", {0}),
+ "AttrValue had value with type list(int) when list(bool) "
+ "expected");
+}
+
+TEST_F(NodeDefBuilderTest, AttrMin) {
+ Op(OpDefBuilder("AttrMin").Attr("a: int >= 5"));
+
+ ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
+ op: "AttrMin" attr { key: "a" value { i: 12 } } )proto");
+
+ ExpectInvalid(Builder().Attr("a", 2),
+ "Value for attr 'a' of 2 must be at least minimum 5");
+}
+
+TEST_F(NodeDefBuilderTest, AttrListMin) {
+ Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2"));
+
+ ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto(
+ op: "AttrListMin"
+ attr { key: "a" value { list { i: [1, 2] } } } )proto");
+
+ ExpectInvalid(Builder().Attr("a", {17}),
+ "Length for attr 'a' of 1 must be at least minimum 2");
+}
+
+TEST_F(NodeDefBuilderTest, AttrEnum) {
+ Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}"));
+
+ ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto(
+ op: "AttrEnum"
+ attr { key: "a" value { s: "oranges" } } )proto");
+
+ ExpectInvalid(
+ Builder().Attr("a", "invalid"),
+ "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
+ "\"apples\", \"oranges\"");
+}
+
+TEST_F(NodeDefBuilderTest, AttrEnumList) {
+ Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})"));
+
+ ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto(
+ op: "AttrEnumList"
+ attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto");
+
+ ExpectInvalid(
+ Builder().Attr("a", {"apples", "invalid", "oranges"}),
+ "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
+ "\"apples\", \"oranges\"");
+}
+
+TEST_F(NodeDefBuilderTest, AttrShape) {
+ Op(OpDefBuilder("AttrShape").Attr("a: shape"));
+
+ ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape { dim { size: 5 } } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape {
+ dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {},
+ R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape {
+ dim { size: 3 } dim { size: 2 } } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto(
+ op: "AttrShape"
+ attr { key: "a" value { shape { } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrDefault) {
+ Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrDefault"
+ attr { key: "a" value { s: "banana" } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto(
+ op: "AttrDefault"
+ attr { key: "a" value { s: "kiwi" } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrManyDefault) {
+ Op(OpDefBuilder("AttrManyDefault")
+ .Attr("a: string = 'banana'")
+ .Attr("b: string = 'kiwi'"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrManyDefault"
+ attr { key: "a" value { s: "banana" } }
+ attr { key: "b" value { s: "kiwi" } } )proto");
+
+ Op(OpDefBuilder("AttrManyDefaultWithMandatory")
+ .Attr("a: string = 'banana'")
+ .Attr("b: string = 'kiwi'")
+ .Attr("c: string"));
+
+ ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto(
+ op: "AttrManyDefaultWithMandatory"
+ attr { key: "c" value { s: "strawberry" } }
+ attr { key: "a" value { s: "banana" } }
+ attr { key: "b" value { s: "kiwi" } } )proto");
+
+ Op(OpDefBuilder("AttrManyDefaultAndInferred")
+ .Input("input: T")
+ .Attr("T: {float, double}")
+ .Attr("a: string")
+ .Attr("b: list(string) >= 1")
+ .Attr("c: bool = true")
+ .Attr("d: float = 0.3")
+ .Attr("e: string")
+ .Attr("f: float = 0.25"));
+
+ ExpectSuccess(Builder()
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("a", "foo")
+ .Attr("e", "foo")
+ .Attr("b", std::vector<string>({"bar", "baz"}))
+ .Attr("f", 1.0f),
+ {DT_FLOAT}, {}, R"proto(
+ op: "AttrManyDefaultAndInferred"
+ input: "a"
+ attr { key: "T" value { type: DT_FLOAT } }
+ attr { key: "a" value { s: "foo" } }
+ attr { key: "e" value { s: "foo" } }
+ attr { key: "b" value { list { s: "bar" s: "baz" } } }
+ attr { key: "f" value { f: 1.0 } }
+ attr { key: "c" value { b: true } }
+ attr { key: "d" value { f: 0.3 } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrListDefault) {
+ Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrListDefault"
+ attr { key: "a" value { list { i: [5, 15] } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
+ op: "AttrListDefault"
+ attr { key: "a" value { list { i: 3 } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
+ op: "AttrListDefault"
+ attr { key: "a" value { list { } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) {
+ Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []"));
+
+ ExpectSuccess(Builder(), {}, {}, R"proto(
+ op: "AttrEmptyListDefault"
+ attr { key: "a" value { list { } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
+ op: "AttrEmptyListDefault"
+ attr { key: "a" value { list { i: 3 } } } )proto");
+
+ ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
+ op: "AttrEmptyListDefault"
+ attr { key: "a" value { list { } } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, NIntsIn) {
+ Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {},
+ R"proto(
+ op: "NIntsIn" input: ["a", "a:1"]
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)),
+ {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
+ op: "NIntsIn"
+ input: ["a", "a:1", "a:2", "a:3", "a:4"]
+ attr { key: "N" value { i: 5 } } )proto");
+
+ ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)),
+ {"2 errors while building NodeDef",
+ "Input 'a' passed string expected int32"});
+
+ ExpectInvalid(Builder().Input(FakeInput(1)),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectFailures(
+ Builder().Input(FakeInput(DT_INT32)),
+ {"2 errors while building NodeDef",
+ "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+
+ ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
+ "Input 'a' passed string expected int32 while");
+
+ ExpectFailures(
+ Builder().Input(FakeInput()),
+ {"2 errors while building NodeDef",
+ "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicIn) {
+ Op(OpDefBuilder("NPolymorphicIn")
+ .Input("a: N*T")
+ .Attr("T: type")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32},
+ {}, R"proto(
+ op: "NPolymorphicIn" input: ["a", "a:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
+ {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
+ op: "NPolymorphicIn"
+ input: ["a", "a:1", "a:2"]
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectFailures(
+ Builder().Input(FakeInput(2)),
+ {"2 errors while building NodeDef",
+ "Could not infer type for input 'a': No attr named 'T' in NodeDef:",
+ "0 inputs specified of 1 inputs in Op"});
+
+ ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})),
+ "Input 'a' passed string expected int32 while");
+
+ ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
+ "Input 'a' passed string expected int32 while");
+
+ ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectFailure(Builder().Input("in", 0, DT_INT32),
+ "Single tensor passed to 'a', expected list while");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) {
+ Op(OpDefBuilder("NPolymorphicRestrictIn")
+ .Input("a: N*T")
+ .Attr("T: {string, bool}")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {},
+ R"proto(
+ op: "NPolymorphicRestrictIn" input: ["a", "a:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
+ {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
+ op: "NPolymorphicRestrictIn"
+ input: ["a", "a:1", "a:2"]
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, NInTwice) {
+ Op(OpDefBuilder("NInTwice")
+ .Input("a: N*int32")
+ .Input("b: N*string")
+ .Attr("N: int >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)),
+ {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto(
+ op: "NInTwice"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
+ R"proto(
+ op: "NInTwice" attr { key: "N" value { i: 0 } } )proto");
+
+ ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+}
+
+TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) {
+ Op(OpDefBuilder("NInPolymorphicTwice")
+ .Input("a: N*T")
+ .Input("b: N*T")
+ .Attr("T: type")
+ .Attr("N: int >= 0"));
+
+ ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()),
+ {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
+ op: "NInPolymorphicTwice"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+
+ ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
+ "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)),
+ "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
+}
+
+TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) {
+ Op(OpDefBuilder("NInTwoTypeVariables")
+ .Input("a: N*S")
+ .Input("b: N*T")
+ .Attr("S: type")
+ .Attr("T: type")
+ .Attr("N: int >= 0"));
+
+ ExpectSuccess(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)),
+ {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
+ op: "NInTwoTypeVariables"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "S" value { type: DT_INT32 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectSuccess(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)),
+ {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
+ op: "NInTwoTypeVariables"
+ input: ["a", "a:1", "b", "b:1"]
+ attr { key: "N" value { i: 2 } }
+ attr { key: "S" value { type: DT_INT32 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)),
+ "Inconsistent values for attr 'N' 3 vs. 1 while");
+}
+
+TEST_F(NodeDefBuilderTest, InPolymorphicTwice) {
+ Op(OpDefBuilder("InPolymorphicTwice")
+ .Input("a: N*T")
+ .Input("b: M*T")
+ .Attr("T: type")
+ .Attr("N: int >= 0")
+ .Attr("M: int >= 0"));
+
+ ExpectSuccess(
+ Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)),
+ {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
+ op: "InPolymorphicTwice"
+ input: ["a", "b", "b:1", "b:2"]
+ attr { key: "N" value { i: 1 } }
+ attr { key: "T" value { type: DT_INT32 } }
+ attr { key: "M" value { i: 3 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)),
+ {DT_BOOL}, {}, R"proto(
+ op: "InPolymorphicTwice" input: "a"
+ attr { key: "N" value { i: 1 } }
+ attr { key: "T" value { type: DT_BOOL } }
+ attr { key: "M" value { i: 0 } } )proto");
+
+ ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)),
+ {DT_BOOL}, {}, R"proto(
+ op: "InPolymorphicTwice" input: "b"
+ attr { key: "N" value { i: 0 } }
+ attr { key: "M" value { i: 1 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectFailure(
+ Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
+ "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
+}
+
+TEST_F(NodeDefBuilderTest, NIntsOut) {
+ Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
+ op: "NIntsOut"
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32},
+ R"proto(
+ op: "NIntsOut"
+ attr { key: "N" value { i: 3 } } )proto");
+
+ ExpectInvalid(Builder().Attr("N", 1),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectInvalid(Builder().Attr("N", {3}),
+ "AttrValue had value with type list(int) when int expected");
+
+ ExpectInvalid(Builder(), "NodeDef missing attr 'N' from");
+}
+
+TEST_F(NodeDefBuilderTest, NIntsOutDefault) {
+ Op(OpDefBuilder("NIntsOutDefault")
+ .Output("a: N*int32")
+ .Attr("N: int >= 2 = 3"));
+
+ ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto(
+ op: "NIntsOutDefault"
+ attr { key: "N" value { i: 3 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
+ op: "NIntsOutDefault"
+ attr { key: "N" value { i: 2 } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicOut) {
+ Op(OpDefBuilder("NPolymorphicOut")
+ .Output("a: N*T")
+ .Attr("T: type")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {},
+ {DT_INT32, DT_INT32}, R"proto(
+ op: "NPolymorphicOut"
+ attr { key: "T" value { type: DT_INT32 } }
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {},
+ {DT_STRING, DT_STRING, DT_STRING}, R"proto(
+ op: "NPolymorphicOut"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_STRING } } )proto");
+
+ ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING),
+ "Value for attr 'N' of 1 must be at least minimum 2");
+
+ ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}),
+ "AttrValue had value with type list(type) when type expected");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) {
+ Op(OpDefBuilder("NPolymorphicOutDefault")
+ .Output("a: N*T")
+ .Attr("T: type = DT_BOOL")
+ .Attr("N: int >= 2 = 2"));
+
+ ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "T" value { type: DT_BOOL } }
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL},
+ R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32},
+ R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "T" value { type: DT_INT32 } }
+ attr { key: "N" value { i: 2 } } )proto");
+
+ ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {},
+ {DT_INT32, DT_INT32, DT_INT32}, R"proto(
+ op: "NPolymorphicOutDefault"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_INT32 } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) {
+ Op(OpDefBuilder("NPolymorphicRestrictOut")
+ .Output("a: N*T")
+ .Attr("T: {string, bool}")
+ .Attr("N: int >= 2"));
+
+ ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {},
+ {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto(
+ op: "NPolymorphicRestrictOut"
+ attr { key: "N" value { i: 3 } }
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32),
+ "Value for attr 'T' of int32 is not in the list of allowed "
+ "values: string, bool");
+}
+
+TEST_F(NodeDefBuilderTest, RefIn) {
+ Op(OpDefBuilder("RefIn").Input("a: Ref(int32)"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {},
+ R"proto(
+ op: "RefIn" input: "a" )proto");
+
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)),
+ "Input 'a' passed bool_ref expected int32_ref while");
+
+ ExpectFailure(Builder().Input(FakeInput(DT_INT32)),
+ "Input 'a' passed int32 expected int32_ref while");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicRefIn) {
+ Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type"));
+
+ ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {},
+ R"proto(
+ op: "PolymorphicRefIn" input: "a"
+ attr { key: "T" value { type: DT_BOOL } } )proto");
+
+ ExpectFailure(Builder().Input(FakeInput(DT_BOOL)),
+ "Input 'a' passed bool expected ref type while");
+}
+
+TEST_F(NodeDefBuilderTest, RefOut) {
+ Op(OpDefBuilder("RefOut").Output("a: Ref(string)"));
+
+ ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto(
+ op: "RefOut" )proto");
+}
+
+TEST_F(NodeDefBuilderTest, PolymorphicRefOut) {
+ Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type"));
+
+ ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto(
+ op: "PolymorphicRefOut"
+ attr { key: "t" value { type: DT_BOOL } } )proto");
+}
+
+TEST_F(NodeDefBuilderTest, SpecifyDevice) {
+ Op(OpDefBuilder("SpecifyDevice"));
+
+ ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto(
+ op: "SpecifyDevice" device: "ADevice" )proto");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
new file mode 100644
index 0000000000..aefd416187
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -0,0 +1,414 @@
+#include "tensorflow/core/framework/node_def_util.h"
+
+#include <algorithm>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+string SummarizeNodeDef(const NodeDef& node_def) {
+ string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
+
+ // We sort the attrs so the output is deterministic.
+ std::vector<string> attr_names;
+ attr_names.reserve(node_def.attr().size());
+ for (const auto& attr : node_def.attr()) {
+ attr_names.push_back(attr.first);
+ }
+ std::sort(attr_names.begin(), attr_names.end());
+ bool first = true;
+ for (const string& attr_name : attr_names) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ auto iter = node_def.attr().find(attr_name);
+ strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second));
+ }
+
+ // Consider the device to be a final attr with name "_device".
+ if (!node_def.device().empty()) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ strings::StrAppend(&ret, "_device=\"", node_def.device(), "\"");
+ }
+ strings::StrAppend(&ret, "](");
+
+ // Output inputs, including control inputs, verbatim.
+ first = true;
+ for (const string& input : node_def.input()) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ strings::StrAppend(&ret, input);
+ }
+ strings::StrAppend(&ret, ")");
+ return ret;
+}
+
+const AttrValue* AttrSlice::Find(const string& attr_name) const {
+ auto iter = attrs_->find(attr_name);
+ if (iter == attrs_->end()) return nullptr;
+ return &iter->second;
+}
+
+Status AttrSlice::Find(const string& attr_name,
+ const AttrValue** attr_value) const {
+ *attr_value = Find(attr_name);
+ if (*attr_value != nullptr) {
+ return Status::OK();
+ }
+ Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
+ if (ndef_) {
+ s = AttachDef(s, *ndef_);
+ }
+ return s;
+}
+
+// The ... is to allow the caller to inject some value validation code. Use
+// just ; if no additional validation code is needed.
+#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
+ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ TYPE* value) { \
+ const AttrValue* attr_value; \
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \
+ const auto& v = attr_value->FIELD(); \
+ __VA_ARGS__; \
+ *value = CAST; \
+ return Status::OK(); \
+ } \
+ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ std::vector<TYPE>* value) { \
+ const AttrValue* attr_value; \
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
+ for (const auto& v : attr_value->list().FIELD()) { \
+ __VA_ARGS__; \
+ value->APPEND_OP(CAST); \
+ } \
+ return Status::OK(); \
+ }
+
+DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
+DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
+DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
+ if (static_cast<int64>(static_cast<int32>(v)) != v) {
+ return errors::InvalidArgument("Attr ", attr_name,
+ " has value ", v,
+ " out of range for an int32");
+ })
+DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
+// std::vector<bool> specialization does not have emplace_back until
+// c++14, so we have to use push_back (see
+// http://en.cppreference.com/w/cpp/container/vector/emplace_back)
+DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
+DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
+ ;)
+DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
+DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), ;)
+DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
+ if (!t.FromProto(v)) {
+ return errors::InvalidArgument(
+ "Attr ", attr_name, " has value ", v.ShortDebugString(),
+ " that can't be converted to a Tensor");
+ })
+
+#undef DEFINE_GET_ATTR
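+
+// For illustration only (a sketch, not the generated code): the int32
+// instantiation above expands to a scalar overload roughly equivalent to
+//
+//   Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+//                      int32* value) {
+//     const AttrValue* attr_value;
+//     TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+//     TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "int"));
+//     const auto& v = attr_value->i();
+//     // Injected validation: reject values that do not fit in an int32.
+//     if (static_cast<int64>(static_cast<int32>(v)) != v) {
+//       return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
+//                                      " out of range for an int32");
+//     }
+//     *value = static_cast<int32>(v);
+//     return Status::OK();
+//   }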
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataTypeVector* value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
+ for (const auto& v : attr_value->list().type()) {
+ value->push_back(static_cast<DataType>(v));
+ }
+ return Status::OK();
+}
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const TensorProto** value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
+ *value = &attr_value->tensor();
+ return Status::OK();
+}
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const NameAttrList** value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
+ *value = &attr_value->func();
+ return Status::OK();
+}
+
+namespace { // Helper for InOutTypesForNode().
+
+Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
+ DataTypeVector* sig) {
+ const int original_size = sig->size();
+ if (!arg_def.number_attr().empty()) {
+ // Same type repeated "repeats" times.
+ int32 repeats = -1;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
+ if (repeats < 0) {
+ return errors::InvalidArgument("Value for number_attr() ", repeats,
+ " < 0");
+ }
+
+ if (!arg_def.type_attr().empty()) {
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
+ for (int i = 0; i < repeats; ++i) {
+ sig->push_back(dtype);
+ }
+ } else if (arg_def.type() != DT_INVALID) {
+ for (int i = 0; i < repeats; ++i) {
+ sig->push_back(arg_def.type());
+ }
+ } else {
+ return errors::InvalidArgument("Missing type or type_attr field in ",
+ arg_def.ShortDebugString());
+ }
+ } else if (!arg_def.type_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
+ sig->push_back(attr_value->type());
+ } else if (!arg_def.type_list_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
+ for (int dtype : attr_value->list().type()) {
+ sig->push_back(static_cast<DataType>(dtype));
+ }
+ } else if (arg_def.type() != DT_INVALID) {
+ sig->push_back(arg_def.type());
+ } else {
+ return errors::InvalidArgument("No type fields in ",
+ arg_def.ShortDebugString());
+ }
+ if (arg_def.is_ref()) {
+ // For all types that were added by this function call, make them refs.
+ for (size_t i = original_size; i < sig->size(); ++i) {
+ (*sig)[i] = MakeRefType((*sig)[i]);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs, DataTypeVector* outputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ for (const auto& arg : op_def.output_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
+ }
+ return Status::OK();
+}
+
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
+ if (node_def.op() != op_def.name()) {
+ return errors::InvalidArgument("NodeDef op '", node_def.op(),
+ "' does not match ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ bool seen_control = false;
+ size_t num_inputs = 0;
+ // TODO(josh11b): Unify the input field validation.
+ for (const string& input : node_def.input()) {
+ if (StringPiece(input).starts_with("^")) {
+ seen_control = true;
+ if (input.find(':') != string::npos) {
+ return errors::InvalidArgument("Control input '", input,
+ "' must not have ':' in NodeDef: ",
+ SummarizeNodeDef(node_def));
+ }
+ } else if (seen_control) {
+ return errors::InvalidArgument("Non-control input '", input,
+ "' after control input in NodeDef: ",
+ SummarizeNodeDef(node_def));
+ } else {
+ ++num_inputs;
+ }
+ }
+
+ std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
+ for (const auto& attr : op_def.attr()) {
+ if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
+ return errors::InvalidArgument("OpDef has duplicate attr name '",
+ attr.name(), "': ",
+ SummarizeOpDef(op_def));
+ }
+ }
+ for (const auto& attr : node_def.attr()) {
+ // Allow internal optional attributes with names starting with "_".
+ if (StringPiece(attr.first).starts_with("_")) {
+ continue;
+ }
+ auto iter = op_attrs.find(attr.first);
+ if (iter == op_attrs.end()) {
+ return errors::InvalidArgument("NodeDef mentions attr '", attr.first,
+ "' not in ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ",
+ SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def));
+ // Keep track of which attr names have (not) been found in the NodeDef.
+ op_attrs.erase(iter);
+ }
+
+ // Were all attrs in the OpDef found in the NodeDef?
+ if (!op_attrs.empty()) {
+ string attrs;
+ for (const auto& attr_pair : op_attrs) {
+ if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
+ strings::StrAppend(&attrs, attr_pair.first);
+ }
+ return errors::InvalidArgument("NodeDef missing attr",
+ op_attrs.size() == 1 ? " '" : "s '", attrs,
+ "' from ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ // Validate the number of inputs.
+ DataTypeVector inputs, outputs;
+ TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
+
+ if (num_inputs != inputs.size()) {
+ return errors::InvalidArgument(
+ "NodeDef expected inputs '", DataTypeVectorString(inputs),
+ "' do not match ", num_inputs, " inputs specified; ",
+ SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ return Status::OK();
+}
+
+namespace { // Helpers for NameRangesForNode()
+
+Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
+ const OpDef& op_def, int* num) {
+ if (!arg_def.number_attr().empty()) {
+ // Same type repeated "num" times.
+ return GetNodeAttr(node_def, arg_def.number_attr(), num);
+ } else if (!arg_def.type_list_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
+ *num = attr_value->list().type_size();
+ } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
+ *num = 1;
+ } else {
+ return errors::InvalidArgument("Argument '", arg_def.name(),
+ "' incorrectly specified in op definition: ",
+ SummarizeOpDef(op_def));
+ }
+ return Status::OK();
+}
+
+Status NameRangesHelper(const NodeDef& node_def,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ const OpDef& op_def, NameRangeMap* result) {
+ int start = 0;
+ int num;
+ for (const auto& arg : args) {
+ TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
+ (*result)[arg.name()] = std::make_pair(start, start + num);
+ start += num;
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs) {
+ TF_RETURN_IF_ERROR(
+ NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
+ return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
+}
+
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
+ for (const auto& attr_def : op_def.attr()) {
+ AttrSlice attrs(*node_def);
+ if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
+ AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
+ }
+ }
+}
+
+namespace {
+
+static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+static RE2* valid_data_input_pattern =
+ new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?");
+static RE2* valid_control_input_pattern =
+ new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+
+} // namespace
+
+Status ValidateOpInput(const string& input_name, bool* is_control_input) {
+ *is_control_input = false;
+ if (RE2::FullMatch(input_name, *valid_data_input_pattern)) {
+ return Status::OK();
+ } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) {
+ *is_control_input = true;
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Illegal op input name '", input_name, "'");
+ }
+}
+
+Status ValidateOpName(const string& op_name) {
+ if (RE2::FullMatch(op_name, *valid_op_name_pattern)) {
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Illegal op name '", op_name, "'");
+ }
+}
+
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
+ Status s = ValidateOpName(node_def.name());
+ if (!s.ok()) {
+ return AttachDef(s, node_def);
+ }
+ bool in_control_inputs = false;
+ for (const string& input_name : node_def.input()) {
+ bool is_control_input;
+ s = ValidateOpInput(input_name, &is_control_input);
+ if (!s.ok()) {
+ return AttachDef(s, node_def);
+ }
+
+ if (in_control_inputs && !is_control_input) {
+ return AttachDef(errors::InvalidArgument(
+ "All control inputs must follow all data inputs"),
+ node_def);
+ }
+ in_control_inputs = is_control_input;
+ }
+ return Status::OK();
+}
+
+Status AttachDef(const Status& status, const NodeDef& node_def) {
+ Status ret = status;
+ errors::AppendToMessage(
+ &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]"));
+ return ret;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
new file mode 100644
index 0000000000..fce6fd2433
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.h
@@ -0,0 +1,157 @@
+#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Produce a human-readable version of a NodeDef that is more concise
+// than a text-format proto.
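+// For example (as checked in node_def_util_test.cc), a node "n" running op
+// "In" with attr T=DT_FLOAT and data input "a" summarizes as
+//   n = In[T=DT_FLOAT](a)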
+string SummarizeNodeDef(const NodeDef& node_def);
+
+typedef protobuf::Map<string, AttrValue> AttrValueMap;
+
+// Adds an attr with name <name> and value <value> to *node_def.
+// The type of the attr is based on the type of value.
+template <class T>
+void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
+ AttrValue attr_value;
+ SetAttrValue(std::forward<T>(value), &attr_value);
+ node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
+
+// Version to work around C++'s "perfect" forwarding not being able to
+// forward {...} initialization.
+template <class T>
+void AddNodeAttr(const string& name, std::initializer_list<T> value,
+ NodeDef* node_def) {
+ AttrValue attr_value;
+ SetAttrValue(value, &attr_value);
+ node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+}
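+
+// Illustrative usage sketch (mirrors node_def_util_test.cc; "node" is a
+// hypothetical NodeDef):
+//   AddNodeAttr("N", 2, &node);                    // stored as an "int" attr
+//   AddNodeAttr("e", "orange", &node);             // stored as a "string" attr
+//   AddNodeAttr("T", {DT_INT32, DT_BOOL}, &node);  // "list(type)", via the
+//                                                  // initializer_list overload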
+
+class AttrSlice {
+ public:
+ AttrSlice(const NodeDef& node_def) // NOLINT(runtime/explicit)
+ : ndef_(&node_def),
+ attrs_(&ndef_->attr()) {}
+
+ explicit AttrSlice(const AttrValueMap* a) : attrs_(a) {}
+
+ // Returns the attr with attr_name if found. Otherwise, returns
+ // nullptr.
+ const AttrValue* Find(const string& attr_name) const;
+
+ // Returns the attr_value for attr_name if found. Otherwise, returns a
+ // NotFound status.
+ Status Find(const string& attr_name, const AttrValue** attr_value) const;
+
+ private:
+ const NodeDef* ndef_ = nullptr;
+ const AttrValueMap* attrs_;
+};
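+
+// Illustrative usage sketch ("node_def" is a hypothetical NodeDef):
+//   AttrSlice attrs(node_def);              // implicit conversion from NodeDef
+//   const AttrValue* t = attrs.Find("T");   // nullptr if no attr named "T"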
+
+// Look up the attr with name attr_name and set *value to its value. If no
+// attr with attr_name is found in node_def, or the attr does not have
+// a matching type, a non-ok status will be returned.
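+//
+// Illustrative usage sketch ("N" and "T" are hypothetical attr names on a
+// NodeDef node_def; the AttrSlice is constructed implicitly):
+//   int64 n;
+//   DataType dtype;
+//   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "N", &n));
+//   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "T", &dtype));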
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ string* value); // type: "string"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ int64* value); // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ int32* value); // type: "int"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ float* value); // type: "float"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ bool* value); // type: "bool"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataType* value); // type: "type"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ TensorShapeProto* value); // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ TensorShape* value); // type: "shape"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ Tensor* value); // type: "tensor"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<string>* value); // type "list(string)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<int64>* value); // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<int32>* value); // type "list(int)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<float>* value); // type "list(float)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<bool>* value); // type "list(bool)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<DataType>* value); // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataTypeVector* value); // type "list(type)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<TensorShapeProto>* value); // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<TensorShape>* value); // type "list(shape)"
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ std::vector<Tensor>* value); // type: "list(tensor)"
+
+// This version avoids copying the TensorProto.
+// REQUIRES: Must not use *value beyond the lifetime of the NodeDef (or
+// AttrValueMap) that backs attrs.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const TensorProto** value); // type: "tensor"
+
+// This version avoids copying the NameAttrList.
+// REQUIRES: Must not use *value beyond the lifetime of the NodeDef (or
+// AttrValueMap) that backs attrs.
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const NameAttrList** value); // type: "func"
+
+// Computes the input and output types for a specific node, for
+// attr-style ops.
+// REQUIRES: ValidateOpDef(op_def).ok()
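+// For example (mirroring the NPolymorphicIn op in node_def_builder_test.cc):
+// for an op with input "a: N*T" and a NodeDef carrying N=2 and T=DT_INT32,
+// *inputs is filled with {DT_INT32, DT_INT32} and *outputs is left empty.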
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs, DataTypeVector* outputs);
+
+// Validates that the NodeDef:
+// * Defines all expected attrs from the OpDef.
+//  * All attrs satisfy the constraints from the OpDef.
+//  * Has a signature matching the input/output types computed by
+//    InOutTypesForNode().
+// etc.
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
+
+// Computes the mapping from input/output argument name to the
+// corresponding input/output index range. For example,
+// input "foo" corresponds to input indices
+// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
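+// For example (mirroring the InPolymorphicTwice op in
+// node_def_builder_test.cc): with inputs "a: N*T" and "b: M*T" and a NodeDef
+// carrying N=1 and M=3, the resulting *inputs map is
+//   { "a" -> [0, 1), "b" -> [1, 4) }.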
+typedef std::unordered_map<string, std::pair<int, int>> NameRangeMap;
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs);
+
+// Adds default values to *node_def for unspecified attrs from op_def.
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
+
+// Validates the syntax of a NodeDef provided externally.
+//
+// The following is an EBNF-style syntax for NodeDef objects. Note that
+// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
+// which contain many other fields that are not (currently) validated.
+//
+// Node = NodeName, Inputs
+// Inputs = ( DataInput * ), ( ControlInput * )
+// DataInput = NodeName, ( ":", ( "0" | [1-9], [0-9] * ) ) ?
+// ControlInput = "^", NodeName
+// NodeName = [A-Za-z0-9.], [A-Za-z0-9_.\-/] *
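+//
+// For example (per the grammar above and the patterns in node_def_util.cc):
+// "add", "add:3" and "^add" are syntactically valid inputs, "^add:2" is not
+// (control inputs take no output index), and a data input appearing after a
+// control input is also rejected.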
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
+
+// Returns "status" with kernel's NodeDef attached as additional text
+// in the error message.
+Status AttachDef(const Status& status, const NodeDef& node_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
new file mode 100644
index 0000000000..71f1760a09
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -0,0 +1,442 @@
+#include "tensorflow/core/framework/node_def_util.h"
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+OpDef ToOpDef(const OpDefBuilder& builder) {
+ OpDef op_def;
+ EXPECT_OK(builder.Finalize(&op_def));
+ return op_def;
+}
+
+NodeDef ToNodeDef(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+NodeDef ToNodeDef(const NodeDefBuilder& builder) {
+ NodeDef node_def;
+ EXPECT_OK(builder.Finalize(&node_def));
+ return node_def;
+}
+
+void ExpectSuccess(const NodeDef& good, const OpDef& op_def) {
+ EXPECT_EQ(Status::OK(), ValidateNodeDef(good, op_def))
+ << "NodeDef: " << SummarizeNodeDef(good)
+ << "; OpDef: " << SummarizeOpDef(op_def);
+}
+
+void ExpectFailure(const NodeDef& bad, const OpDef& op_def,
+ const string& message) {
+ Status status = ValidateNodeDef(bad, op_def);
+
+ EXPECT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad)
+ << "; OpDef: " << SummarizeOpDef(op_def);
+ if (status.ok()) return;
+
+ EXPECT_TRUE(errors::IsInvalidArgument(status))
+ << status << "; NodeDef: " << SummarizeNodeDef(bad)
+ << "; OpDef: " << SummarizeOpDef(op_def);
+
+ LOG(INFO) << "Message: " << status.error_message();
+ EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ << "NodeDef: " << SummarizeNodeDef(bad)
+ << "; OpDef: " << SummarizeOpDef(op_def) << "\nActual error: " << status
+ << "\nDoes not contain: " << message;
+}
+
+TEST(NodeDefUtilTest, In) {
+ const OpDef op = ToOpDef(OpDefBuilder("In").Input("i: T").Attr("T: type"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'In' input:'a' attr { key:'T' value { type:DT_FLOAT } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = In[T=DT_FLOAT](a)", SummarizeNodeDef(node_def));
+
+ // Mismatching Op names.
+ NodeDef bad = node_def;
+ bad.set_op("Wrong");
+ ExpectFailure(bad, op, "NodeDef op 'Wrong' does not match Op<name=In;");
+
+ // Missing attr
+ bad = node_def;
+ bad.clear_attr();
+ ExpectFailure(bad, op, "NodeDef missing attr 'T' from Op<name=In;");
+
+ // Extra attr
+ bad = node_def;
+ AddNodeAttr("EXTRA", 17, &bad);
+ ExpectFailure(bad, op, "NodeDef mentions attr 'EXTRA' not in Op<name=In;");
+
+ // Attr has wrong type
+ bad = node_def;
+ bad.clear_attr();
+ AddNodeAttr("T", 17, &bad);
+ ExpectFailure(
+ bad, op,
+ "AttrValue had value with type int when type expected\n\t for attr "
+ "'T'\n\t; NodeDef: ");
+
+ // Wrong number of inputs
+ bad = node_def;
+ bad.add_input("b");
+ ExpectFailure(
+ bad, op,
+ "NodeDef expected inputs 'float' do not match 2 inputs specified;");
+
+ bad = node_def;
+ bad.clear_input();
+ ExpectFailure(
+ bad, op,
+ "NodeDef expected inputs 'float' do not match 0 inputs specified;");
+
+ // Control inputs must appear after data inputs
+ NodeDef good = node_def;
+ good.add_input("^b");
+ ExpectSuccess(good, op);
+
+ bad = node_def;
+ bad.clear_input();
+ bad.add_input("^b");
+ bad.add_input("a");
+ ExpectFailure(bad, op,
+ "Invalid argument: Non-control input 'a' after control input "
+ "in NodeDef:");
+
+ bad = node_def;
+ bad.add_input("^b:0");
+ ExpectFailure(bad, op, "Control input '^b:0' must not have ':' in NodeDef:");
+}
+
+TEST(NodeDefUtilTest, Out) {
+ const OpDef op =
+ ToOpDef(OpDefBuilder("Out").Output("o: T").Attr("T: numbertype"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'Out' attr { key:'T' value { type:DT_INT32 } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = Out[T=DT_INT32]()", SummarizeNodeDef(node_def));
+
+ // Non-number type.
+ NodeDef bad = node_def;
+ bad.clear_attr();
+ AddNodeAttr("T", DT_STRING, &bad);
+ ExpectFailure(bad, op,
+ "Value for attr 'T' of string is not in the list of allowed "
+ "values: float, double, int64, int32, uint8, int16, int8, "
+ "complex64, qint8, quint8, qint32");
+}
+
+TEST(NodeDefUtilTest, Enum) {
+ const OpDef op = ToOpDef(OpDefBuilder("Enum").Attr("e: {'apple','orange'}"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'Enum' attr { key:'e' value { s:'apple' } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = Enum[e=\"apple\"]()", SummarizeNodeDef(node_def));
+
+ NodeDef good = node_def;
+ good.clear_attr();
+ AddNodeAttr("e", "orange", &good);
+ ExpectSuccess(good, op);
+
+ // Non-allowed value.
+ NodeDef bad = node_def;
+ bad.clear_attr();
+ AddNodeAttr("e", "foo", &bad);
+ ExpectFailure(bad, op,
+ "Value for attr 'e' of \"foo\" is not in the list of allowed "
+ "values: \"apple\", \"orange\"");
+}
+
+TEST(NodeDefUtilTest, SameIn) {
+ const OpDef op = ToOpDef(OpDefBuilder("SameIn")
+ .Input("i: N * T")
+ .Attr("N: int >= 2")
+ .Attr("T: {float,double}"));
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'SameIn' input:'a' input:'b'
+ attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_DOUBLE } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = SameIn[N=2, T=DT_DOUBLE](a, b)", SummarizeNodeDef(node_def));
+
+ // Illegal type
+ NodeDef bad = ToNodeDef(R"proto(
+ name:'n' op:'SameIn' input:'a' input:'b'
+ attr { key:'N' value { i:2 } } attr { key:'T' value { type:DT_STRING } }
+ )proto");
+ ExpectFailure(bad, op,
+ "Value for attr 'T' of string is not in the list of allowed "
+ "values: float, double");
+
+ // Too few inputs
+ bad = ToNodeDef(R"proto(
+ name:'n' op:'SameIn' input:'a' input:'b'
+ attr { key:'N' value { i:1 } } attr { key:'T' value { type:DT_FLOAT } }
+ )proto");
+ ExpectFailure(bad, op, "Value for attr 'N' of 1 must be at least minimum 2");
+}
+
+TEST(NodeDefUtilTest, AnyIn) {
+ const OpDef op =
+ ToOpDef(OpDefBuilder("AnyIn").Input("i: T").Attr("T: list(type) >= 1"));
+
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectSuccess(node_def, op);
+
+ EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a, b)",
+ SummarizeNodeDef(node_def));
+
+ const NodeDef bad = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' attr { key:'T' value { list { } } }
+ )proto");
+ ExpectFailure(bad, op, "Length for attr 'T' of 0 must be at least minimum 1");
+
+ // With proto3 semantics, an empty value {} is indistinguishable from a value
+ // with an empty list in it. So we simply expect to get a message complaining
+ // about empty list for value {}.
+ const NodeDef bad2 = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' attr { key:'T' value { } }
+ )proto");
+ ExpectFailure(bad2, op,
+ "Length for attr 'T' of 0 must be at least minimum 1");
+}
+
+TEST(NodeDefUtilTest, Device) {
+ const OpDef op_def1 = ToOpDef(OpDefBuilder("None"));
+ const NodeDef node_def1 =
+ ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17"));
+ ExpectSuccess(node_def1, op_def1);
+ EXPECT_EQ("d = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1));
+
+ const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
+ const NodeDef node_def2 =
+ ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5"));
+ ExpectSuccess(node_def2, op_def2);
+ EXPECT_EQ("d = WithAttr[v=7, _device=\"/cpu:5\"]()",
+ SummarizeNodeDef(node_def2));
+}
+
+void ExpectValidSyntax(const NodeDef& good) {
+ EXPECT_EQ(Status::OK(), ValidateExternalNodeDefSyntax(good))
+ << "NodeDef: " << SummarizeNodeDef(good);
+}
+
+void ExpectInvalidSyntax(const NodeDef& bad, const string& message) {
+ Status status = ValidateExternalNodeDefSyntax(bad);
+
+ ASSERT_FALSE(status.ok()) << "NodeDef: " << SummarizeNodeDef(bad);
+
+ EXPECT_TRUE(errors::IsInvalidArgument(status))
+ << status << "; NodeDef: " << SummarizeNodeDef(bad);
+
+ EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ << "NodeDef: " << SummarizeNodeDef(bad) << ", " << status << ", "
+ << message;
+}
+
+TEST(NodeDefUtilTest, ValidSyntax) {
+ const NodeDef node_def = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectValidSyntax(node_def);
+
+ const NodeDef node_def_explicit_inputs = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a:0' input:'b:123'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectValidSyntax(node_def_explicit_inputs);
+
+ EXPECT_EQ("n = AnyIn[T=[DT_INT32, DT_STRING]](a:0, b:123)",
+ SummarizeNodeDef(node_def_explicit_inputs));
+
+ const NodeDef node_def_control_input = ToNodeDef(R"proto(
+ name:'n-' op:'AnyIn' input:'a' input:'^b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectValidSyntax(node_def_control_input);
+
+ const NodeDef node_def_invalid_name = ToNodeDef(R"proto(
+ name:'n:0' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_invalid_name, "Illegal op name 'n:0'");
+
+ const NodeDef node_def_internal_name = ToNodeDef(R"proto(
+ name:'_n' op:'AnyIn' input:'a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_internal_name, "Illegal op name '_n'");
+
+ const NodeDef node_def_internal_input_name = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'_a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_internal_input_name,
+ "Illegal op input name '_a'");
+
+ const NodeDef node_def_invalid_control_input_name = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'a' input:'^b:0'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_invalid_control_input_name,
+ "Illegal op input name '^b:0'");
+
+ const NodeDef node_def_data_input_after_control = ToNodeDef(R"proto(
+ name:'n' op:'AnyIn' input:'^a' input:'b'
+ attr { key:'T' value { list { type: [DT_INT32, DT_STRING] } } }
+ )proto");
+ ExpectInvalidSyntax(node_def_data_input_after_control,
+ "All control inputs must follow all data inputs");
+}
+
+TEST(NameRangesForNodeTest, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs);
+
+ EXPECT_EQ("simple = Simple[](a, b)", SummarizeNodeDef(node_def));
+
+ OpDef bad_op_def = op_def;
+ bad_op_def.mutable_input_arg(0)->clear_type();
+ EXPECT_FALSE(NameRangesForNode(node_def, bad_op_def, &inputs, &outputs).ok());
+}
+
+TEST(NameRangesForNodeTest, Polymorphic) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Polymorphic")
+ .Input("a: T")
+ .Input("b: T")
+ .Output("c: T")
+ .Attr("T: type"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def)
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_INT32)));
+ EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
+ EXPECT_EQ("poly = Polymorphic[T=DT_INT32](a, b)",
+ SummarizeNodeDef(node_def1));
+
+ const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def)
+ .Input(FakeInput(DT_BOOL))
+ .Input(FakeInput(DT_BOOL)));
+ EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
+ EXPECT_EQ("poly = Polymorphic[T=DT_BOOL](a, b)", SummarizeNodeDef(node_def2));
+}
+
+TEST(NameRangesForNodeTest, NRepeats) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("NRepeats")
+ .Input("a: N * int32")
+ .Input("b: N * T")
+ .Output("c: T")
+ .Output("d: N * string")
+ .Output("e: M * bool")
+ .Attr("N: int")
+ .Attr("M: int")
+ .Attr("T: type"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def)
+ .Input(FakeInput(4, DT_INT32))
+ .Input(FakeInput(4, DT_FLOAT))
+ .Attr("M", 3));
+ EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
+ outputs);
+ EXPECT_EQ(
+ "nr = NRepeats[M=3, N=4, T=DT_FLOAT](a, a:1, a:2, a:3, b, b:1, b:2, b:3)",
+ SummarizeNodeDef(node_def1));
+
+ const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def)
+ .Input(FakeInput(2, DT_INT32))
+ .Input(FakeInput(2, DT_DOUBLE))
+ .Attr("M", 7));
+ EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
+ outputs);
+ EXPECT_EQ("nr = NRepeats[M=7, N=2, T=DT_DOUBLE](a, a:1, b, b:1)",
+ SummarizeNodeDef(node_def2));
+
+ NodeDef bad_node_def = node_def2;
+ bad_node_def.clear_attr();
+ EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok());
+}
+
+TEST(NameRangesForNodeTest, TypeList) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("TypeList")
+ .Input("a: T1")
+ .Input("b: T2")
+ .Output("c: T2")
+ .Output("d: T3")
+ .Output("e: T1")
+ .Attr("T1: list(type)")
+ .Attr("T2: list(type)")
+ .Attr("T3: list(type)"));
+ NameRangeMap inputs, outputs;
+ const NodeDef node_def1 =
+ ToNodeDef(NodeDefBuilder("tl", &op_def)
+ .Input(FakeInput({DT_BOOL, DT_FLOAT}))
+ .Input(FakeInput(4, DT_FLOAT))
+ .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING}));
+ EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
+ outputs);
+ EXPECT_EQ(
+ "tl = TypeList[T1=[DT_BOOL, DT_FLOAT],"
+ " T2=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT],"
+ " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
+ SummarizeNodeDef(node_def1));
+
+ const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def)
+ .Input(FakeInput(7, DT_INT32))
+ .Input(FakeInput({DT_DOUBLE}))
+ .Attr("T3", {DT_DOUBLE, DT_STRING}));
+ EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
+ EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
+ EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
+ outputs);
+ EXPECT_EQ(
+ "tl = TypeList[T1=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32,"
+ " DT_INT32, DT_INT32], T2=[DT_DOUBLE], T3=[DT_DOUBLE, DT_STRING]]"
+ "(a, a:1, a:2, a:3, a:4, a:5, a:6, b)",
+ SummarizeNodeDef(node_def2));
+
+ NodeDef bad_node_def = node_def2;
+ bad_node_def.clear_attr();
+ EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
new file mode 100644
index 0000000000..8413d18f33
--- /dev/null
+++ b/tensorflow/core/framework/numeric_op.h
@@ -0,0 +1,96 @@
+#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// One input and one output, both the same type.
+template <class T>
+class UnaryOp : public OpKernel {
+ public:
+ explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt}));
+ }
+};
+
+// Two inputs and one output, all the same type.
+template <class T>
+class BinaryOp : public OpKernel {
+ public:
+ explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt}));
+ }
+};
+
+// For operations where the input and output are the same shape.
+//
+// For usage, see ../framework/elementwise_ops.cc.
+template <class T, class CHILD>
+class UnaryElementWiseOp : public UnaryOp<T> {
+ public:
+ using UnaryOp<T>::UnaryOp;
+
+ void Compute(OpKernelContext* context) override {
+ // Output shape is the same as input shape.
+ const Tensor& input = context->input(0);
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ static_cast<CHILD*>(this)->Operate(context, input, output);
+ }
+};
+
+// For binary elementwise operations.
+template <class T, class CHILD>
+class BinaryElementWiseOp : public BinaryOp<T> {
+ public:
+ using BinaryOp<T>::BinaryOp;
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& a = context->input(0);
+ const Tensor& b = context->input(1);
+
+ if (!context->ValidateInputsAreSameShape(this)) {
+ return;
+ }
+
+ Tensor* output;
+ OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output));
+
+ // Dispatch to the descendant's Operate() function.
+ switch (a.dims()) {
+#define NDIM_CASE(NDIMS) \
+ case NDIMS: { \
+ static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \
+ break; \
+ }
+
+ NDIM_CASE(1);
+ NDIM_CASE(2);
+ NDIM_CASE(3);
+ NDIM_CASE(4);
+ NDIM_CASE(5);
+ NDIM_CASE(6);
+ NDIM_CASE(7);
+ NDIM_CASE(8);
+#undef NDIM_CASE
+
+ default:
+ context->SetStatus(errors::OutOfRange(
+ "We only handle up to Tensor::dims() up to 8, not ", a.dims()));
+ break;
+ }
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
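
A minimal sketch of the CRTP pattern these templates expect: a kernel derives from UnaryElementWiseOp<T, Child> and supplies Operate(), which the Compute() above calls after allocating an output of the same shape as the input. NegateExampleOp is hypothetical and only illustrates the contract.

    #include "tensorflow/core/framework/numeric_op.h"

    namespace tensorflow {

    // Hypothetical element-wise kernel: negates each element of its input.
    template <typename T>
    class NegateExampleOp : public UnaryElementWiseOp<T, NegateExampleOp<T>> {
     public:
      using UnaryElementWiseOp<T, NegateExampleOp<T>>::UnaryElementWiseOp;

      // Invoked by UnaryElementWiseOp<...>::Compute() with an output tensor
      // already allocated to match input.shape().
      void Operate(OpKernelContext* context, const Tensor& input,
                   Tensor* output) {
        auto in = input.flat<T>();
        auto out = output->flat<T>();
        for (int64 i = 0; i < in.size(); ++i) {
          out(i) = -in(i);
        }
      }
    };

    }  // namespace tensorflow
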
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
new file mode 100644
index 0000000000..366f00ae03
--- /dev/null
+++ b/tensorflow/core/framework/numeric_types.h
@@ -0,0 +1,15 @@
+#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+
+#include <complex>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Single precision complex.
+typedef std::complex<float> complex64;
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
new file mode 100644
index 0000000000..15b7eab4da
--- /dev/null
+++ b/tensorflow/core/framework/op.cc
@@ -0,0 +1,135 @@
+#include "tensorflow/core/framework/op.h"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// OpRegistry -----------------------------------------------------------------
+
+OpRegistryInterface::~OpRegistryInterface() {}
+
+OpRegistry::OpRegistry() : initialized_(false) {}
+
+void OpRegistry::Register(std::function<OpDef(void)> func) {
+ mutex_lock lock(mu_);
+ if (initialized_) {
+ OpDef def = func();
+ TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: "
+ << SummarizeOpDef(def);
+ } else {
+ deferred_.push_back(func);
+ }
+}
+
+const OpDef* OpRegistry::LookUp(const string& op_type_name,
+ Status* status) const {
+ const OpDef* op_def = nullptr;
+ bool first_call = false;
+ { // Scope for lock.
+ mutex_lock lock(mu_);
+ first_call = CallDeferred();
+ op_def = gtl::FindWithDefault(registry_, op_type_name, nullptr);
+ // Note: Can't hold mu_ while calling Export() below.
+ }
+ if (first_call) {
+ TF_QCHECK_OK(ValidateKernelRegistrations(this));
+ }
+ if (op_def == nullptr) {
+ status->Update(
+ errors::NotFound("Op type not registered '", op_type_name, "'"));
+ static bool first = true;
+ if (first) {
+ OpList op_list;
+ Export(true, &op_list);
+ LOG(INFO) << "All registered Ops:";
+ for (const auto& op : op_list.op()) {
+ LOG(INFO) << SummarizeOpDef(op);
+ }
+ first = false;
+ }
+ }
+ return op_def;
+}
+
+void OpRegistry::Export(bool include_internal, OpList* ops) const {
+ mutex_lock lock(mu_);
+ CallDeferred();
+
+ std::vector<std::pair<string, const OpDef*>> sorted(registry_.begin(),
+ registry_.end());
+ std::sort(sorted.begin(), sorted.end());
+
+ auto out = ops->mutable_op();
+ out->Clear();
+ out->Reserve(sorted.size());
+
+ for (const auto& item : sorted) {
+ if (include_internal || !StringPiece(item.first).starts_with("_")) {
+ *out->Add() = *item.second;
+ }
+ }
+}
+
+string OpRegistry::DebugString(bool include_internal) const {
+ OpList op_list;
+ Export(include_internal, &op_list);
+ string ret;
+ for (const auto& op : op_list.op()) {
+ strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
+ }
+ return ret;
+}
+
+bool OpRegistry::CallDeferred() const {
+ if (initialized_) return false;
+ initialized_ = true;
+ for (const auto& fn : deferred_) {
+ OpDef def = fn();
+ TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: "
+ << SummarizeOpDef(def);
+ }
+ deferred_.clear();
+ return true;
+}
+
+Status OpRegistry::RegisterAlreadyLocked(const OpDef& def) const {
+ TF_RETURN_IF_ERROR(ValidateOpDef(def));
+
+ std::unique_ptr<OpDef> copy(new OpDef(def));
+ if (gtl::InsertIfNotPresent(&registry_, def.name(), copy.get())) {
+ copy.release(); // Ownership transferred to op_registry
+ return Status::OK();
+ } else {
+ return errors::AlreadyExists("Op with name ", def.name());
+ }
+}
+
+// static
+OpRegistry* OpRegistry::Global() {
+ static OpRegistry* global_op_registry = new OpRegistry;
+ return global_op_registry;
+}
+
+namespace register_op {
+OpDefBuilder& RegisterOp(StringPiece name) {
+ VLOG(1) << "RegisterOp: " << name;
+ OpDefBuilder* b = new OpDefBuilder(name);
+ OpRegistry::Global()->Register([b]() -> ::tensorflow::OpDef {
+ OpDef op_def;
+ TF_QCHECK_OK(b->Finalize(&op_def));
+ delete b;
+ return op_def;
+ });
+ return *b;
+}
+} // namespace register_op
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
new file mode 100644
index 0000000000..95ad32df35
--- /dev/null
+++ b/tensorflow/core/framework/op.h
@@ -0,0 +1,122 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_H_
+#define TENSORFLOW_FRAMEWORK_OP_H_
+
+#include <functional>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Users that want to look up an OpDef by type name should take an
+// OpRegistryInterface. Functions accepting a
+// (const) OpRegistryInterface* may call LookUp() from multiple threads.
+class OpRegistryInterface {
+ public:
+ virtual ~OpRegistryInterface();
+
+ // Returns nullptr and sets *status if no OpDef is registered under that
+ // name, otherwise returns the registered OpDef.
+ // Caller must not delete the returned pointer.
+ virtual const OpDef* LookUp(const string& op_type_name,
+ Status* status) const = 0;
+};
+
+// The standard implementation of OpRegistryInterface, along with a
+// global singleton used for registering OpDefs via the REGISTER
+// macros below. Thread-safe.
+//
+// Example registration:
+// OpRegistry::Global()->Register([]()->OpDef{
+// OpDef def;
+// // Populate def here.
+// return def;
+// });
+class OpRegistry : public OpRegistryInterface {
+ public:
+ OpRegistry();
+ ~OpRegistry() override {}
+
+ // Calls func() and registers the returned OpDef. Since Register()
+ // is normally called during program initialization (before main()),
+ // we defer calling func() until the first call to LookUp() or
+ // Export() (if one of those has already been called, func() is
+ // called immediately).
+ void Register(std::function<OpDef(void)> func);
+
+ const OpDef* LookUp(const string& op_type_name,
+ Status* status) const override;
+
+ // Fills *ops with all registered OpDefs (except those with names
+ // starting with '_' if include_internal == false).
+ void Export(bool include_internal, OpList* ops) const;
+
+ // Returns ASCII-format OpList for all registered OpDefs (except
+ // those with names starting with '_' if include_internal == false).
+ string DebugString(bool include_internal) const;
+
+ // A singleton available at startup.
+ static OpRegistry* Global();
+
+ private:
+ // Calls all the functions in deferred_, registers the OpDefs they return,
+ // and leaves deferred_ empty. Returns true the first time it is called.
+ bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Add 'def' to the registry. On failure, or if there is already an
+ // OpDef with that name registered, returns a non-okay status.
+ Status RegisterAlreadyLocked(const OpDef& def) const
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ mutable mutex mu_;
+ // Functions in deferred_ may only be called with mu_ held.
+ mutable std::vector<std::function<OpDef(void)>> deferred_ GUARDED_BY(mu_);
+ mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_);
+ mutable bool initialized_ GUARDED_BY(mu_);
+};
+
+// Support for defining the OpDef (specifying the semantics of the Op and how
+// it should be created) and registering it in the OpRegistry::Global()
+// registry. Usage:
+//
+// REGISTER_OP("my_op_name")
+// .Attr("<name>:<type>")
+// .Attr("<name>:<type>=<default>")
+// .Input("<name>:<type-expr>")
+// .Input("<name>:Ref(<type-expr>)")
+// .Output("<name>:<type-expr>")
+// .Doc(R"(
+// <1-line summary>
+// <rest of the description (potentially many lines)>
+// <name-of-attr-input-or-output>: <description of name>
+// <name-of-attr-input-or-output>: <description of name;
+// if long, indent the description on subsequent lines>
+// )");
+//
+// Note: .Doc() should be last.
+// For details, see the OpDefBuilder class in op_def_builder.h.
+
+namespace register_op {
+// To call OpRegistry::Global()->Register(...), used by the
+// REGISTER_OP macro below.
+OpDefBuilder& RegisterOp(StringPiece name);
+} // namespace register_op
+
+#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
+#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
+#define REGISTER_OP_UNIQ(ctr, name) \
+ static ::tensorflow::OpDefBuilder& register_op##ctr TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::register_op::RegisterOp(name)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_H_
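
A minimal sketch of the registration macro in use. The op name "ExampleScale" and its signature are hypothetical; they only illustrate the Attr/Input/Output/Doc spec strings documented above.

    #include "tensorflow/core/framework/op.h"

    REGISTER_OP("ExampleScale")
        .Attr("T: {float, double}")
        .Attr("scale: float = 1.0")
        .Input("values: T")
        .Output("scaled: T")
        .Doc(R"(
    Multiplies every element of `values` by the constant `scale`.

    values: The tensor to scale.
    scaled: Same shape and type as `values`.
    scale: The multiplier to apply.
    )");
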
diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto
new file mode 100644
index 0000000000..4a2e90b1b9
--- /dev/null
+++ b/tensorflow/core/framework/op_def.proto
@@ -0,0 +1,142 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Defines an operation. A NodeDef in a GraphDef specifies an Op by
+// using the "op" field which should match the name of a OpDef.
+message OpDef {
+ // Op names starting with an underscore are reserved for internal use.
+ // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
+ string name = 1;
+
+ // For describing inputs and outputs.
+ message ArgDef {
+ // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
+ string name = 1;
+
+ // Human readable description.
+ string description = 2;
+
+ // Describes the type of one or more tensors that are accepted/produced
+ // by this input/output arg. The only legal combinations are:
+ // * For a single tensor: either the "type" field is set or the
+ // "type_attr" field is set to the name of an attr with type "type".
+ // * For a sequence of tensors with the same type: the "number_attr"
+ // field will be set to the name of an attr with type "int", and
+ // either the "type" or "type_attr" field will be set as for
+ // single tensors.
+ // * For a sequence of tensors, the "type_list_attr" field will be set
+ // to the name of an attr with type "list(type)".
+ DataType type = 3;
+ string type_attr = 4; // if specified, attr must have type "type"
+ string number_attr = 5; // if specified, attr must have type "int"
+ // If specified, attr must have type "list(type)", and none of
+ // type, type_attr, and number_attr may be specified.
+ string type_list_attr = 6;
+
+ // For inputs: if true, the inputs are required to be refs.
+ // By default, inputs can be either refs or non-refs.
+ // For outputs: if true, outputs are refs, otherwise they are not.
+ bool is_ref = 16;
+ };
+
+ // Description of the input(s).
+ repeated ArgDef input_arg = 2;
+
+ // Description of the output(s).
+ repeated ArgDef output_arg = 3;
+
+ // Description of the graph-construction-time configuration of this
+ // Op. That is to say, this describes the attr fields that will
+ // be specified in the NodeDef.
+ message AttrDef {
+ // A descriptive name for the argument. May be used, e.g. by the
+ // Python client, as a keyword argument name, and so should match
+ // the regexp "[a-z][a-z0-9_]+".
+ string name = 1;
+
+ // One of the type names from attr_value.proto ("string", "list(string)",
+ // "int", etc.).
+ string type = 2;
+
+ // A reasonable default for this attribute if the user does not supply
+ // a value. If not specified, the user must supply a value.
+ AttrValue default_value = 3;
+
+ // Human-readable description.
+ string description = 4;
+
+ // TODO(josh11b): bool is_optional?
+
+ // --- Constraints ---
+ // These constraints are only in effect if specified. Default is no
+ // constraints.
+
+ // For type == "int", this is a minimum value. For "list(___)"
+ // types, this is the minimum length.
+ bool has_minimum = 5;
+ int64 minimum = 6;
+
+ // The set of allowed values. Has type that is the "list" version
+ // of the "type" field above (uses the ..._list, fields of AttrValue).
+ // If type == "type" or "list(type)" above, then the type_list field
+ // of allowed_values has the set of allowed DataTypes.
+ // If type == "string" or "list(string)", then the s_list field has
+ // the set of allowed strings.
+ AttrValue allowed_values = 7;
+ }
+ repeated AttrDef attr = 4;
+
+ // One-line human-readable description of what the Op does.
+ string summary = 5;
+
+ // Additional, longer human-readable description of what the Op does.
+ string description = 6;
+
+ // -------------------------------------------------------------------------
+ // Which optimizations this operation can participate in.
+
+ // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
+ bool is_commutative = 18;
+
+ // If is_aggregate is true, then this operation accepts N >= 2
+ // inputs and produces 1 output all of the same type. Should be
+ // associative and commutative, and produce output with the same
+ // shape as the input. The optimizer may replace an aggregate op
+ // taking input from multiple devices with a tree of aggregate ops
+ // that aggregate locally within each device (and possibly within
+ // groups of nearby devices) before communicating.
+ // TODO(josh11b): Implement that optimization.
+ bool is_aggregate = 16; // for things like add
+
+ // Other optimizations go here, like
+ // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.
+
+ // -------------------------------------------------------------------------
+ // Optimization constraints.
+
+ // By default Ops may be moved between devices. Stateful ops should
+ // either not be moved, or should only be moved if that state can also
+ // be moved (e.g. via some sort of save / restore).
+ // Stateful ops are guaranteed to never be optimized away by Common
+ // Subexpression Elimination (CSE).
+ bool is_stateful = 17; // for things like variables, queue
+
+ // -------------------------------------------------------------------------
+ // Non-standard options.
+
+ // By default, all inputs to an Op must be initialized Tensors. Ops
+ // that may initialize tensors for the first time should set this
+ // field to true, to allow the Op to take an uninitialized Tensor as
+ // input.
+ bool allows_uninitialized_input = 19; // for Assign, etc.
+};
+
+// A collection of OpDefs
+message OpList {
+ repeated OpDef op = 1;
+};
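
A minimal sketch of an OpDef instance in proto text format, parsed from a C++ string. It shows how ArgDef.type_attr ties each argument's dtype to the "T" attr and how allowed_values restricts that attr; the op "ExamplePairwiseAdd" is hypothetical.

    #include "tensorflow/core/framework/op_def.pb.h"
    #include "tensorflow/core/platform/protobuf.h"

    namespace tensorflow {

    // Builds a hypothetical OpDef purely to illustrate the fields above.
    OpDef ExamplePairwiseAddDef() {
      OpDef op_def;
      protobuf::TextFormat::ParseFromString(R"(
        name: 'ExamplePairwiseAdd'
        input_arg { name: 'a' type_attr: 'T' }   # dtype comes from attr 'T'
        input_arg { name: 'b' type_attr: 'T' }
        output_arg { name: 'sum' type_attr: 'T' }
        attr {
          name: 'T'
          type: 'type'
          allowed_values { list { type: [DT_FLOAT, DT_DOUBLE] } }
        }
        summary: 'Adds two tensors of the same floating-point type.'
        is_commutative: true
      )", &op_def);
      return op_def;
    }

    }  // namespace tensorflow
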
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
new file mode 100644
index 0000000000..7d7c07de4c
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -0,0 +1,447 @@
+#include "tensorflow/core/framework/op_def_builder.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+
+bool RE2Consume(StringPiece* sp, const char* pattern) {
+ RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
+ bool r = RE2::Consume(&base_sp, pattern);
+ *sp = FromRegexpStringPiece(base_sp);
+ return r;
+}
+
+bool RE2Consume(StringPiece* sp, const char* pattern, StringPiece* out) {
+ RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
+ RegexpStringPiece base_out;
+ bool r = RE2::Consume(&base_sp, pattern, &base_out);
+ *sp = FromRegexpStringPiece(base_sp);
+ *out = FromRegexpStringPiece(base_out);
+ return r;
+}
+
+bool RE2Consume(StringPiece* sp, const char* pattern, int64* out) {
+ RegexpStringPiece base_sp = ToRegexpStringPiece(*sp);
+ bool r = RE2::Consume(&base_sp, pattern, out);
+ *sp = FromRegexpStringPiece(base_sp);
+ return r;
+}
+
+string AttrError(StringPiece orig, const string& op_name) {
+ return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name);
+}
+
+#define VERIFY(expr, ...) \
+ do { \
+ if (!(expr)) { \
+ errors->push_back( \
+ strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \
+ return; \
+ } \
+ } while (false)
+
+void FinalizeAttr(StringPiece spec, OpDef* op_def,
+ std::vector<string>* errors) {
+ OpDef::AttrDef* attr = op_def->add_attr();
+ StringPiece orig(spec);
+
+ // Parse "<name>:" at the beginning.
+ StringPiece tmp_name;
+ VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &tmp_name),
+ "Trouble parsing '<name>:'");
+ attr->set_name(tmp_name.data(), tmp_name.size());
+
+ // Read "<type>" or "list(<type>)".
+ bool is_list = RE2Consume(&spec, "list\\s*\\(\\s*");
+ string type;
+ if (spec.Consume("string")) {
+ type = "string";
+ } else if (spec.Consume("int")) {
+ type = "int";
+ } else if (spec.Consume("float")) {
+ type = "float";
+ } else if (spec.Consume("bool")) {
+ type = "bool";
+ } else if (spec.Consume("type")) {
+ type = "type";
+ } else if (spec.Consume("shape")) {
+ type = "shape";
+ } else if (spec.Consume("tensor")) {
+ type = "tensor";
+ } else if (spec.Consume("func")) {
+ type = "func";
+ } else if (spec.Consume("numbertype") || spec.Consume("numerictype")) {
+ type = "type";
+ AttrValue* allowed = attr->mutable_allowed_values();
+ for (DataType dt : NumberTypes()) {
+ allowed->mutable_list()->add_type(dt);
+ }
+ } else if (spec.Consume("quantizedtype")) {
+ type = "type";
+ AttrValue* allowed = attr->mutable_allowed_values();
+ for (DataType dt : QuantizedTypes()) {
+ allowed->mutable_list()->add_type(dt);
+ }
+ } else if (spec.Consume("realnumbertype") ||
+ spec.Consume("realnumerictype")) {
+ type = "type";
+ AttrValue* allowed = attr->mutable_allowed_values();
+ for (DataType dt : RealNumberTypes()) {
+ allowed->mutable_list()->add_type(dt);
+ }
+ } else if (spec.Consume("{")) {
+ // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
+ RE2Consume(&spec, "\\s*");
+ AttrValue* allowed = attr->mutable_allowed_values();
+ if (spec.starts_with("\"") || spec.starts_with("'")) {
+ type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
+ while (true) {
+ StringPiece escaped_string;
+ VERIFY((RE2Consume(&spec, R"xx("((?:[^"\\]|\\.)*)"\s*)xx",
+ &escaped_string) ||
+ RE2Consume(&spec, R"xx('((?:[^'\\]|\\.)*)'\s*)xx",
+ &escaped_string)),
+ "Trouble parsing allowed string at '", spec, "'");
+ string unescaped;
+ string error;
+ VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
+ "Trouble unescaping \"", escaped_string, "\", got error: ",
+ error);
+ allowed->mutable_list()->add_s(unescaped);
+ if (spec.Consume(",")) {
+ RE2Consume(&spec, "\\s*");
+ if (spec.Consume("}")) break; // Allow ending with ", }".
+ } else {
+ VERIFY(spec.Consume("}"),
+ "Expected , or } after strings in list, not: '", spec, "'");
+ break;
+ }
+ }
+ } else { // "{ int32, float, bool }"
+ type = "type";
+ while (true) {
+ StringPiece type_string;
+ VERIFY(RE2Consume(&spec, "([a-z0-9]+)\\s*", &type_string),
+ "Trouble parsing type string at '", spec, "'");
+ DataType dt;
+ VERIFY(DataTypeFromString(type_string, &dt),
+ "Unrecognized type string '", type_string, "'");
+ allowed->mutable_list()->add_type(dt);
+ if (spec.Consume(",")) {
+ RE2Consume(&spec, "\\s*");
+ if (spec.Consume("}")) break; // Allow ending with ", }".
+ } else {
+ VERIFY(spec.Consume("}"),
+ "Expected , or } after types in list, not: '", spec, "'");
+ break;
+ }
+ }
+ }
+ } else {
+ VERIFY(false, "Trouble parsing type string at '", spec, "'");
+ }
+ RE2Consume(&spec, "\\s*");
+
+ // Write the type into *attr.
+ if (is_list) {
+ VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'");
+ RE2Consume(&spec, "\\s*");
+ attr->set_type(strings::StrCat("list(", type, ")"));
+ } else {
+ attr->set_type(type);
+ }
+
+ // Read optional minimum constraint at the end.
+ if ((is_list || type == "int") && spec.Consume(">=")) {
+ int64 min_limit = -999;
+ VERIFY(RE2Consume(&spec, "\\s*(-?\\d+)\\s*", &min_limit),
+ "Could not parse integer lower limit after '>=', found '", spec,
+ "' instead");
+ attr->set_has_minimum(true);
+ attr->set_minimum(min_limit);
+ }
+
+ // Parse default value, if present.
+ if (spec.Consume("=")) {
+ RE2Consume(&spec, "\\s*");
+ VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()),
+ "Could not parse default value '", spec, "'");
+ } else {
+ VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
+ }
+}
+
+#undef VERIFY
+
+string InOutError(bool is_output, StringPiece orig, const string& op_name) {
+ return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig,
+ "\") for Op ", op_name);
+}
+
+#define VERIFY(expr, ...) \
+ do { \
+ if (!(expr)) { \
+ errors->push_back(strings::StrCat( \
+ __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \
+ return; \
+ } \
+ } while (false)
+
+void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def,
+ std::vector<string>* errors) {
+ OpDef::ArgDef* arg =
+ is_output ? op_def->add_output_arg() : op_def->add_input_arg();
+
+ StringPiece orig(spec);
+
+ // Parse "<name>:" at the beginning.
+ StringPiece tmp_name;
+ VERIFY(RE2Consume(&spec, "([a-z][a-z0-9_]*)\\s*:\\s*", &tmp_name),
+ "Trouble parsing 'name:'");
+ arg->set_name(tmp_name.data(), tmp_name.size());
+
+ // Detect "Ref(...)".
+ if (RE2Consume(&spec, "Ref\\s*\\(\\s*")) {
+ arg->set_is_ref(true);
+ }
+
+ { // Parse "<name|type>" or "<name>*<name|type>".
+ StringPiece first, second, type_or_attr;
+ VERIFY(RE2Consume(&spec, "([a-zA-Z][a-zA-Z0-9_]*)\\s*", &first),
+ "Trouble parsing either a type or an attr name at '", spec, "'");
+ if (RE2Consume(&spec, "[*]\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*", &second)) {
+ arg->set_number_attr(first.data(), first.size());
+ type_or_attr = second;
+ } else {
+ type_or_attr = first;
+ }
+ DataType dt;
+ if (DataTypeFromString(type_or_attr, &dt)) {
+ arg->set_type(dt);
+ } else {
+ const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def);
+ VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'");
+ if (attr->type() == "type") {
+ arg->set_type_attr(type_or_attr.data(), type_or_attr.size());
+ } else {
+ VERIFY(attr->type() == "list(type)", "Reference to attr '",
+ type_or_attr, "' with type ", attr->type(),
+ " that isn't type or list(type)");
+ arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size());
+ }
+ }
+ }
+
+ // Closing ) for Ref(.
+ if (arg->is_ref()) {
+ VERIFY(RE2Consume(&spec, "\\)\\s*"),
+ "Did not find closing ')' for 'Ref(', instead found: '", spec, "'");
+ }
+
+ // Should not have anything else.
+ VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end");
+
+ // Int attrs that are the length of an input or output get a default
+ // minimum of 1.
+ if (!arg->number_attr().empty()) {
+ OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def);
+ if (attr != nullptr && !attr->has_minimum()) {
+ attr->set_has_minimum(true);
+ attr->set_minimum(1);
+ }
+ } else if (!arg->type_list_attr().empty()) {
+ // If an input or output has type specified by a list(type) attr,
+ // it gets a default minimum of 1 as well.
+ OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def);
+ if (attr != nullptr && attr->type() == "list(type)" &&
+ !attr->has_minimum()) {
+ attr->set_has_minimum(true);
+ attr->set_minimum(1);
+ }
+ }
+}
+
+#undef VERIFY
+
+int num_leading_spaces(StringPiece s) {
+ size_t i = 0;
+ while (i < s.size() && s[i] == ' ') {
+ ++i;
+ }
+ return i;
+}
+
+void FinalizeDoc(const string& text, OpDef* op_def,
+ std::vector<string>* errors) {
+ std::vector<string> lines = str_util::Split(text, '\n');
+
+ // Remove trailing spaces.
+ for (string& line : lines) {
+ str_util::StripTrailingWhitespace(&line);
+ }
+
+ // First non-blank line -> summary.
+ int l = 0;
+ while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
+ if (static_cast<size_t>(l) < lines.size()) {
+ op_def->set_summary(lines[l]);
+ ++l;
+ }
+ while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l;
+
+ // Lines until we see name: -> description.
+ int start_l = l;
+ while (static_cast<size_t>(l) < lines.size() &&
+ !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) {
+ ++l;
+ }
+ int end_l = l;
+ // Trim trailing blank lines from the description.
+ while (start_l < end_l && lines[end_l - 1].empty()) --end_l;
+ string desc = str_util::Join(
+ gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n");
+ if (!desc.empty()) op_def->set_description(desc);
+
+ // name: description
+ // possibly continued on the next line
+ // if so, we remove the minimum indent
+ StringPiece name;
+ std::vector<StringPiece> description;
+ while (static_cast<size_t>(l) < lines.size()) {
+ description.clear();
+ description.push_back(lines[l]);
+ RE2Consume(&description.back(), "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*", &name);
+ ++l;
+ while (static_cast<size_t>(l) < lines.size() &&
+ !RE2::PartialMatch(lines[l], "^[a-zA-Z][a-zA-Z0-9_]*\\s*:")) {
+ description.push_back(lines[l]);
+ ++l;
+ }
+ // Remove any trailing blank lines.
+ while (!description.empty() && description.back().empty()) {
+ description.pop_back();
+ }
+ // Compute the minimum indent of all lines after the first.
+ int min_indent = -1;
+ for (size_t i = 1; i < description.size(); ++i) {
+ if (!description[i].empty()) {
+ int indent = num_leading_spaces(description[i]);
+ if (min_indent < 0 || indent < min_indent) min_indent = indent;
+ }
+ }
+ // Remove min_indent spaces from all lines after the first.
+ for (size_t i = 1; i < description.size(); ++i) {
+ if (!description[i].empty()) description[i].remove_prefix(min_indent);
+ }
+ // Concatenate lines into a single string.
+ const string complete(str_util::Join(description, "\n"));
+
+ // Find name.
+ bool found = false;
+ for (int i = 0; !found && i < op_def->input_arg_size(); ++i) {
+ if (op_def->input_arg(i).name() == name) {
+ op_def->mutable_input_arg(i)->set_description(complete);
+ found = true;
+ }
+ }
+ for (int i = 0; !found && i < op_def->output_arg_size(); ++i) {
+ if (op_def->output_arg(i).name() == name) {
+ op_def->mutable_output_arg(i)->set_description(complete);
+ found = true;
+ }
+ }
+ for (int i = 0; !found && i < op_def->attr_size(); ++i) {
+ if (op_def->attr(i).name() == name) {
+ op_def->mutable_attr(i)->set_description(complete);
+ found = true;
+ }
+ }
+ if (!found) {
+ errors->push_back(
+ strings::StrCat("No matching input/output/attr for name '", name,
+ "' from Doc() for Op ", op_def->name()));
+ return;
+ }
+ }
+}
+
+} // namespace
+
+OpDefBuilder::OpDefBuilder(StringPiece op_name) {
+ op_def_.set_name(op_name.ToString()); // NOLINT
+}
+
+OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
+ attrs_.emplace_back(spec.data(), spec.size());
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
+ inputs_.emplace_back(spec.data(), spec.size());
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
+ outputs_.emplace_back(spec.data(), spec.size());
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+ if (!doc_.empty()) {
+ errors_.push_back(
+ strings::StrCat("Extra call to Doc() for Op ", op_def_.name()));
+ } else {
+ doc_.assign(text.data(), text.size());
+ }
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetIsCommutative() {
+ op_def_.set_is_commutative(true);
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetIsAggregate() {
+ op_def_.set_is_aggregate(true);
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetIsStateful() {
+ op_def_.set_is_stateful(true);
+ return *this;
+}
+
+OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
+ op_def_.set_allows_uninitialized_input(true);
+ return *this;
+}
+
+Status OpDefBuilder::Finalize(OpDef* op_def) const {
+ std::vector<string> errors = errors_;
+ *op_def = op_def_;
+
+ for (StringPiece attr : attrs_) {
+ FinalizeAttr(attr, op_def, &errors);
+ }
+ for (StringPiece input : inputs_) {
+ FinalizeInputOrOutput(input, false, op_def, &errors);
+ }
+ for (StringPiece output : outputs_) {
+ FinalizeInputOrOutput(output, true, op_def, &errors);
+ }
+ FinalizeDoc(doc_, op_def, &errors);
+
+ if (errors.empty()) return Status::OK();
+ return errors::InvalidArgument(str_util::Join(errors, "\n"));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
new file mode 100644
index 0000000000..017338c508
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -0,0 +1,109 @@
+// Class and associated machinery for specifying an Op's OpDef for Op
+// registration.
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_OP() macro.
+class OpDefBuilder {
+ public:
+ // Constructs an OpDef with just the name field set.
+ explicit OpDefBuilder(StringPiece op_name);
+
+ // Adds an attr to this OpDefBuilder (and returns *this). The spec has
+ // format "<name>:<type>" or "<name>:<type>=<default>"
+ // where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*
+ // (by convention only using capital letters for attrs that can be inferred)
+ // <type> can be:
+ // "string", "int", "float", "bool", "type", "shape", or "tensor"
+ // "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
+ // (meaning "type" with a restriction on valid values)
+ // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
+ // (meaning "string" with a restriction on valid values)
+ // "list(string)", ..., "list(tensor)", "list(numbertype)", ...
+ // (meaning lists of the above types)
+ // "int >= 2" (meaning "int" with a restriction on valid values)
+ // "list(string) >= 2", "list(int) >= 2"
+ // (meaning "list(string)" / "list(int)" with length at least 2)
+ // <default>, if included, should use the Proto text format
+ // of <type>. For lists use [a, b, c] format.
+ //
+ // Note that any attr specifying the length of an input or output will
+ // get a default minimum of 1 unless the >= # syntax is used.
+ //
+ // TODO(josh11b): Perhaps support restrictions and defaults as optional
+ // extra arguments to Attr() instead of encoding them in the spec string.
+ // TODO(josh11b): Would like to have better dtype handling for tensor attrs:
+ // * Ability to say the type of an input/output matches the type of
+ // the tensor.
+ // * Ability to restrict the type of the tensor like the existing
+ // restrictions for type attrs.
+ // Perhaps by linking the type of the tensor to a type attr?
+ OpDefBuilder& Attr(StringPiece spec);
+
+ // Adds an input or output to this OpDefBuilder (and returns *this).
+ // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
+ // where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
+ // * For a single tensor: <type>
+ // * For a sequence of tensors with the same type: <number>*<type>
+ // * For a sequence of tensors with different types: <type-list>
+ // Where:
+ // <type> is either one of "float", "int32", "string", ...
+ // or the name of an attr (see above) with type "type".
+ // <number> is the name of an attr with type "int".
+ // <type-list> is the name of an attr with type "list(type)".
+ // TODO(josh11b): Indicate Ref() via an optional argument instead of
+ // in the spec?
+ // TODO(josh11b): SparseInput() and SparseOutput() matching the Python
+ // handling?
+ OpDefBuilder& Input(StringPiece spec);
+ OpDefBuilder& Output(StringPiece spec);
+
+ // Turns on the indicated boolean flag in this OpDefBuilder (and
+ // returns *this).
+ OpDefBuilder& SetIsCommutative();
+ OpDefBuilder& SetIsAggregate();
+ OpDefBuilder& SetIsStateful();
+ OpDefBuilder& SetAllowsUninitializedInput();
+
+ // Adds docs to this OpDefBuilder (and returns *this).
+ // Docs have the format:
+ // <1-line summary>
+ // <rest of the description>
+ // <name>: <description of name>
+ // <name>: <description of name>
+ // <if long, indent the description on subsequent lines>
+ // Where <name> is the name of an attr, input, or output. Please
+ // wrap docs at 72 columns so that they may be indented in the
+ // generated output. For tensor inputs or outputs (not attrs), you
+ // may start the description with an "=" (like name:= <description>)
+ // to suppress the automatically-generated type documentation in
+ // generated output.
+ OpDefBuilder& Doc(StringPiece text);
+
+ // Sets *op_def to the requested OpDef, or returns an error.
+ // Must be called after all of the above methods.
+ // Note that OpDefBuilder only reports parsing errors. You should also
+ // call ValidateOpDef() to detect other problems.
+ Status Finalize(OpDef* op_def) const;
+
+ private:
+ OpDef op_def_;
+ std::vector<string> attrs_;
+ std::vector<string> inputs_;
+ std::vector<string> outputs_;
+ string doc_;
+ std::vector<string> errors_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
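
A minimal sketch of using OpDefBuilder directly, outside the REGISTER_OP macro, exercising the spec formats documented above: a length attr with a minimum, a restricted type attr, an "N * T" input, and the commutative/aggregate flags. The op name "ExampleAddN" is hypothetical.

    #include "tensorflow/core/framework/op_def_builder.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    // Builds a hypothetical AddN-style OpDef for illustration.
    OpDef ExampleAddNDef() {
      OpDef op_def;
      Status status = OpDefBuilder("ExampleAddN")
                          .Attr("N: int >= 2")     // number of inputs, at least 2
                          .Attr("T: numbertype")   // restricted "type" attr
                          .Input("values: N * T")  // N tensors, all of type T
                          .Output("sum: T")
                          .SetIsCommutative()
                          .SetIsAggregate()
                          .Finalize(&op_def);
      CHECK(status.ok()) << status;  // parsing errors surface here
      return op_def;
    }

    }  // namespace tensorflow
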
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
new file mode 100644
index 0000000000..e53bad7075
--- /dev/null
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -0,0 +1,519 @@
+#include "tensorflow/core/framework/op_def_builder.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+static void CanonicalizeAttrTypeListOrder(OpDef* def) {
+ for (int i = 0; i < def->attr_size(); i++) {
+ AttrValue* a = def->mutable_attr(i)->mutable_allowed_values();
+ std::sort(a->mutable_list()->mutable_type()->begin(),
+ a->mutable_list()->mutable_type()->end());
+ }
+}
+
+class OpDefBuilderTest : public ::testing::Test {
+ protected:
+ OpDefBuilder b() { return OpDefBuilder("Test"); }
+
+ void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_OK(status);
+ if (status.ok()) {
+ OpDef expected;
+ protobuf::TextFormat::ParseFromString(
+ strings::StrCat("name: 'Test' ", proto), &expected);
+ // Allow different orderings
+ CanonicalizeAttrTypeListOrder(&op_def);
+ CanonicalizeAttrTypeListOrder(&expected);
+ EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
+ }
+ }
+
+ void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_OK(status);
+ if (status.ok()) {
+ OpDef expected;
+ protobuf::TextFormat::ParseFromString(
+ strings::StrCat("name: 'Test' ", proto), &expected);
+ EXPECT_EQ(op_def.ShortDebugString(), expected.ShortDebugString());
+ }
+ }
+
+ void ExpectFailure(const OpDefBuilder& builder, string error) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_FALSE(status.ok());
+ if (!status.ok()) {
+ EXPECT_EQ(status.error_message(), error);
+ }
+ }
+};
+
+TEST_F(OpDefBuilderTest, Attr) {
+ ExpectSuccess(b().Attr("a:string"), "attr: { name: 'a' type: 'string' }");
+ ExpectSuccess(b().Attr("A: int"), "attr: { name: 'A' type: 'int' }");
+ ExpectSuccess(b().Attr("a1 :float"), "attr: { name: 'a1' type: 'float' }");
+ ExpectSuccess(b().Attr("a_a : bool"), "attr: { name: 'a_a' type: 'bool' }");
+ ExpectSuccess(b().Attr("aB : type"), "attr: { name: 'aB' type: 'type' }");
+ ExpectSuccess(b().Attr("aB_3\t: shape"),
+ "attr: { name: 'aB_3' type: 'shape' }");
+ ExpectSuccess(b().Attr("t: tensor"), "attr: { name: 't' type: 'tensor' }");
+ ExpectSuccess(b().Attr("XYZ\t:\tlist(type)"),
+ "attr: { name: 'XYZ' type: 'list(type)' }");
+ ExpectSuccess(b().Attr("f: func"), "attr { name: 'f' type: 'func'}");
+}
+
+TEST_F(OpDefBuilderTest, AttrFailure) {
+ ExpectFailure(
+ b().Attr("_:string"),
+ "Trouble parsing '<name>:' from Attr(\"_:string\") for Op Test");
+ ExpectFailure(
+ b().Attr("9:string"),
+ "Trouble parsing '<name>:' from Attr(\"9:string\") for Op Test");
+ ExpectFailure(b().Attr(":string"),
+ "Trouble parsing '<name>:' from Attr(\":string\") for Op Test");
+ ExpectFailure(b().Attr("string"),
+ "Trouble parsing '<name>:' from Attr(\"string\") for Op Test");
+ ExpectFailure(b().Attr("a:invalid"),
+ "Trouble parsing type string at 'invalid' from "
+ "Attr(\"a:invalid\") for Op Test");
+ ExpectFailure(
+ b().Attr("b:"),
+ "Trouble parsing type string at '' from Attr(\"b:\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
+ // Types with restrictions.
+ ExpectSuccess(b().Attr("a:numbertype"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
+ "DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
+ ExpectSuccess(b().Attr("a:realnumbertype"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
+ "DT_INT8] } } }");
+ ExpectSuccess(b().Attr("a:quantizedtype"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
+ ExpectSuccess(b().Attr("a:{string,int32}"),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_STRING, DT_INT32] } } }");
+ ExpectSuccess(b().Attr("a: { float , complex64 } "),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_COMPLEX64] } } }");
+ ExpectSuccess(b().Attr("a: {float, complex64,} "),
+ "attr: { name: 'a' type: 'type' allowed_values { list { type: "
+ "[DT_FLOAT, DT_COMPLEX64] } }");
+ ExpectSuccess(b().Attr(R"(a: { "X", "yz" })"),
+ "attr: { name: 'a' type: 'string' allowed_values { list { s: "
+ "['X', 'yz'] } } }");
+ ExpectSuccess(b().Attr(R"(a: { "X", "yz", })"),
+ "attr: { name: 'a' type: 'string' allowed_values { list { s: "
+ "['X', 'yz'] } } }");
+ ExpectSuccess(
+ b().Attr("i: int >= -5"),
+ "attr: { name: 'i' type: 'int' has_minimum: true minimum: -5 }");
+}
+
+TEST_F(OpDefBuilderTest, AttrRestrictionFailure) {
+ ExpectFailure(
+ b().Attr("a:{}"),
+ "Trouble parsing type string at '}' from Attr(\"a:{}\") for Op Test");
+ ExpectFailure(
+ b().Attr("a:{,}"),
+ "Trouble parsing type string at ',}' from Attr(\"a:{,}\") for Op Test");
+ ExpectFailure(b().Attr("a:{invalid}"),
+ "Unrecognized type string 'invalid' from Attr(\"a:{invalid}\") "
+ "for Op Test");
+ ExpectFailure(b().Attr("a:{\"str\", float}"),
+ "Trouble parsing allowed string at 'float}' from "
+ "Attr(\"a:{\"str\", float}\") for Op Test");
+ ExpectFailure(b().Attr("a:{ float, \"str\" }"),
+ "Trouble parsing type string at '\"str\" }' from Attr(\"a:{ "
+ "float, \"str\" }\") for Op Test");
+ ExpectFailure(b().Attr("a:{float,,string}"),
+ "Trouble parsing type string at ',string}' from "
+ "Attr(\"a:{float,,string}\") for Op Test");
+ ExpectFailure(b().Attr("a:{float,,}"),
+ "Trouble parsing type string at ',}' from "
+ "Attr(\"a:{float,,}\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
+ ExpectSuccess(
+ b().Attr("a:list(realnumbertype)"),
+ "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
+ "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
+ "DT_INT8] } } }");
+ ExpectSuccess(
+ b().Attr("a:list(quantizedtype)"),
+ "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
+ "[DT_QINT8, DT_QUINT8, DT_QINT32] } } }");
+ ExpectSuccess(
+ b().Attr("a: list({float, string, bool})"),
+ "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
+ "[DT_FLOAT, DT_STRING, DT_BOOL] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ "one fish", "two fish" }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "['one fish', 'two fish'] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ 'red fish', 'blue fish' }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "['red fish', 'blue fish'] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ "single' ", 'double"' }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "[\"single' \", 'double\"'] } } }");
+ ExpectSuccess(
+ b().Attr(R"(a: list({ 'escape\'\n', "from\\\"NY" }))"),
+ "attr: { name: 'a' type: 'list(string)' allowed_values { list { s: "
+ "[\"escape'\\n\", 'from\\\\\"NY'] } } }");
+}
+
+TEST_F(OpDefBuilderTest, AttrListWithMinLength) {
+ ExpectSuccess(
+ b().Attr("i: list(bool) >= 4"),
+ "attr: { name: 'i' type: 'list(bool)' has_minimum: true minimum: 4 }");
+}
+
+TEST_F(OpDefBuilderTest, AttrWithDefaults) {
+ ExpectSuccess(b().Attr(R"(a:string="foo")"),
+ "attr: { name: 'a' type: 'string' default_value { s:'foo' } }");
+ ExpectSuccess(b().Attr(R"(a:string='foo')"),
+ "attr: { name: 'a' type: 'string' default_value { s:'foo' } }");
+ ExpectSuccess(b().Attr("a:float = 1.25"),
+ "attr: { name: 'a' type: 'float' default_value { f: 1.25 } }");
+ ExpectSuccess(b().Attr("a:tensor = { dtype: DT_INT32 int_val: 5 }"),
+ "attr: { name: 'a' type: 'tensor' default_value { tensor {"
+ " dtype: DT_INT32 int_val: 5 } } }");
+ ExpectSuccess(b().Attr("a:shape = { dim { size: 3 } dim { size: 4 } }"),
+ "attr: { name: 'a' type: 'shape' default_value { shape {"
+ " dim { size: 3 } dim { size: 4 } } } }");
+}
+
+TEST_F(OpDefBuilderTest, AttrFailedDefaults) {
+ ExpectFailure(b().Attr(R"(a:int="foo")"),
+ "Could not parse default value '\"foo\"' from "
+ "Attr(\"a:int=\"foo\"\") for Op Test");
+ ExpectFailure(b().Attr("a:float = [1.25]"),
+ "Could not parse default value '[1.25]' from Attr(\"a:float = "
+ "[1.25]\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, AttrListWithDefaults) {
+ ExpectSuccess(b().Attr(R"(a:list(string)=["foo", "bar"])"),
+ "attr: { name: 'a' type: 'list(string)' "
+ "default_value { list { s: ['foo', 'bar'] } } }");
+ ExpectSuccess(b().Attr("a:list(bool)=[true, false, true]"),
+ "attr: { name: 'a' type: 'list(bool)' "
+ "default_value { list { b: [true, false, true] } } }");
+ ExpectSuccess(b().Attr(R"(a:list(int)=[0, -1, 2, -4, 8])"),
+ "attr: { name: 'a' type: 'list(int)' "
+ "default_value { list { i: [0, -1, 2, -4, 8] } } }");
+}
+
+TEST_F(OpDefBuilderTest, AttrFailedListDefaults) {
+ ExpectFailure(b().Attr(R"(a:list(int)=["foo"])"),
+ "Could not parse default value '[\"foo\"]' from "
+ "Attr(\"a:list(int)=[\"foo\"]\") for Op Test");
+ ExpectFailure(b().Attr(R"(a:list(int)=[7, "foo"])"),
+ "Could not parse default value '[7, \"foo\"]' from "
+ "Attr(\"a:list(int)=[7, \"foo\"]\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = [[1.25]]"),
+ "Could not parse default value '[[1.25]]' from "
+ "Attr(\"a:list(float) = [[1.25]]\") for Op Test");
+ ExpectFailure(b().Attr("a:list(float) = 1.25"),
+ "Could not parse default value '1.25' from "
+ "Attr(\"a:list(float) = 1.25\") for Op Test");
+ ExpectFailure(b().Attr(R"(a:list(string)='foo')"),
+ "Could not parse default value ''foo'' from "
+ "Attr(\"a:list(string)='foo'\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, InputOutput) {
+ ExpectSuccess(b().Input("a: int32"),
+ "input_arg: { name: 'a' type: DT_INT32 }");
+ ExpectSuccess(b().Output("b: string"),
+ "output_arg: { name: 'b' type: DT_STRING }");
+ ExpectSuccess(b().Input("c: float "),
+ "input_arg: { name: 'c' type: DT_FLOAT }");
+ ExpectSuccess(b().Output("d: Ref(bool)"),
+ "output_arg: { name: 'd' type: DT_BOOL is_ref: true }");
+ ExpectOrdered(b().Input("a: bool")
+ .Output("c: complex64")
+ .Input("b: int64")
+ .Output("d: string"),
+ "input_arg: { name: 'a' type: DT_BOOL } "
+ "input_arg: { name: 'b' type: DT_INT64 } "
+ "output_arg: { name: 'c' type: DT_COMPLEX64 } "
+ "output_arg: { name: 'd' type: DT_STRING }");
+}
+
+TEST_F(OpDefBuilderTest, PolymorphicInputOutput) {
+ ExpectSuccess(b().Input("a: foo").Attr("foo: type"),
+ "input_arg: { name: 'a' type_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'type' }");
+ ExpectSuccess(b().Output("a: foo").Attr("foo: { bool, int32 }"),
+ "output_arg: { name: 'a' type_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'type' "
+ "allowed_values: { list { type: [DT_BOOL, DT_INT32] } } }");
+}
+
+TEST_F(OpDefBuilderTest, InputOutputListSameType) {
+ ExpectSuccess(b().Input("a: n * int32").Attr("n: int"),
+ "input_arg: { name: 'a' number_attr: 'n' type: DT_INT32 } "
+ "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 }");
+ // Polymorphic case:
+ ExpectSuccess(b().Output("b: n * foo").Attr("n: int").Attr("foo: type"),
+ "output_arg: { name: 'b' number_attr: 'n' type_attr: 'foo' } "
+ "attr: { name: 'n' type: 'int' has_minimum: true minimum: 1 } "
+ "attr: { name: 'foo' type: 'type' }");
+}
+
+TEST_F(OpDefBuilderTest, InputOutputListAnyType) {
+ ExpectSuccess(
+ b().Input("c: foo").Attr("foo: list(type)"),
+ "input_arg: { name: 'c' type_list_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 }");
+ ExpectSuccess(
+ b().Output("c: foo").Attr("foo: list({string, float})"),
+ "output_arg: { name: 'c' type_list_attr: 'foo' } "
+ "attr: { name: 'foo' type: 'list(type)' has_minimum: true minimum: 1 "
+ "allowed_values: { list { type: [DT_STRING, DT_FLOAT] } } }");
+}
+
+TEST_F(OpDefBuilderTest, InputOutputFailure) {
+ ExpectFailure(b().Input("9: int32"),
+ "Trouble parsing 'name:' from Input(\"9: int32\") for Op Test");
+ ExpectFailure(
+ b().Output("_: int32"),
+ "Trouble parsing 'name:' from Output(\"_: int32\") for Op Test");
+ ExpectFailure(b().Input(": int32"),
+ "Trouble parsing 'name:' from Input(\": int32\") for Op Test");
+ ExpectFailure(b().Output("int32"),
+ "Trouble parsing 'name:' from Output(\"int32\") for Op Test");
+ ExpectFailure(
+ b().Input("CAPS: int32"),
+ "Trouble parsing 'name:' from Input(\"CAPS: int32\") for Op Test");
+ ExpectFailure(b().Input("a: _"),
+ "Trouble parsing either a type or an attr name at '_' from "
+ "Input(\"a: _\") for Op Test");
+ ExpectFailure(b().Input("a: 9"),
+ "Trouble parsing either a type or an attr name at '9' from "
+ "Input(\"a: 9\") for Op Test");
+ ExpectFailure(b().Input("a: 9 * int32"),
+ "Trouble parsing either a type or an attr name at '9 * int32' "
+ "from Input(\"a: 9 * int32\") for Op Test");
+ ExpectFailure(
+ b().Input("a: x * _").Attr("x: type"),
+ "Extra '* _' unparsed at the end from Input(\"a: x * _\") for Op Test");
+ ExpectFailure(b().Input("a: x * y extra").Attr("x: int").Attr("y: type"),
+ "Extra 'extra' unparsed at the end from Input(\"a: x * y "
+ "extra\") for Op Test");
+ ExpectFailure(b().Input("a: Ref(int32"),
+ "Did not find closing ')' for 'Ref(', instead found: '' from "
+ "Input(\"a: Ref(int32\") for Op Test");
+ ExpectFailure(b().Input("a: Ref(x y").Attr("x: type"),
+ "Did not find closing ')' for 'Ref(', instead found: 'y' from "
+ "Input(\"a: Ref(x y\") for Op Test");
+ ExpectFailure(
+ b().Input("a: x"),
+ "Reference to unknown attr 'x' from Input(\"a: x\") for Op Test");
+ ExpectFailure(
+ b().Input("a: x * y").Attr("x: int"),
+ "Reference to unknown attr 'y' from Input(\"a: x * y\") for Op Test");
+ ExpectFailure(b().Input("a: x").Attr("x: int"),
+ "Reference to attr 'x' with type int that isn't type or "
+ "list(type) from Input(\"a: x\") for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, Set) {
+ ExpectSuccess(b().SetIsStateful(), "is_stateful: true");
+ ExpectSuccess(b().SetIsCommutative().SetIsAggregate(),
+ "is_commutative: true is_aggregate: true");
+}
+
+TEST_F(OpDefBuilderTest, DocUnpackSparseFeatures) {
+ ExpectOrdered(b().Input("sf: string")
+ .Output("indices: int32")
+ .Output("ids: int64")
+ .Output("weights: float")
+ .Doc(R"doc(
+Converts a vector of strings with dist_belief::SparseFeatures to tensors.
+
+Note that indices, ids and weights are vectors of the same size and have
+one-to-one correspondence between their elements. ids and weights are each
+obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in
+1...size(sf). Note that if sf[i].weight is not set, the default value for the
+weight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were
+extracted from sf[i], then index[j] is set to i.
+
+sf: vector of string, where each element is the string encoding of
+ SparseFeatures proto.
+indices: vector of indices inside sf
+ids: vector of id extracted from the SparseFeatures proto.
+weights: vector of weight extracted from the SparseFeatures proto.
+)doc"),
+ R"proto(
+input_arg {
+ name: "sf"
+ description: "vector of string, where each element is the string encoding of\nSparseFeatures proto."
+ type: DT_STRING
+}
+output_arg {
+ name: "indices"
+ description: "vector of indices inside sf"
+ type: DT_INT32
+}
+output_arg {
+ name: "ids"
+ description: "vector of id extracted from the SparseFeatures proto."
+ type: DT_INT64
+}
+output_arg {
+ name: "weights"
+ description: "vector of weight extracted from the SparseFeatures proto."
+ type: DT_FLOAT
+}
+summary: "Converts a vector of strings with dist_belief::SparseFeatures to tensors."
+description: "Note that indices, ids and weights are vectors of the same size and have\none-to-one correspondence between their elements. ids and weights are each\nobtained by sequentially concatenating sf[i].id and sf[i].weight, for i in\n1...size(sf). Note that if sf[i].weight is not set, the default value for the\nweight is assumed to be 1.0. Also for any j, if ids[j] and weights[j] were\nextracted from sf[i], then index[j] is set to i."
+)proto");
+}
+
+TEST_F(OpDefBuilderTest, DocConcat) {
+ ExpectOrdered(b().Input("concat_dim: int32")
+ .Input("values: num_values * dtype")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("num_values: int >= 2")
+ .Doc(R"doc(
+Concatenate N Tensors along one dimension.
+
+concat_dim: The (scalar) dimension along which to concatenate. Must be
+ in the range [0, rank(values...)).
+values: The N Tensors to concatenate. Their ranks and types must match,
+ and their sizes must match in all dimensions except concat_dim.
+output: A Tensor with the concatenation of values stacked along the
+ concat_dim dimension. This Tensor's shape matches the Tensors in
+ values, except in concat_dim where it has the sum of the sizes.
+)doc"),
+ R"proto(
+input_arg {
+ name: "concat_dim"
+ description: "The (scalar) dimension along which to concatenate. Must be\nin the range [0, rank(values...))."
+ type: DT_INT32
+}
+input_arg {
+ name: "values"
+ description: "The N Tensors to concatenate. Their ranks and types must match,\nand their sizes must match in all dimensions except concat_dim."
+ type_attr: "dtype"
+ number_attr: "num_values"
+}
+output_arg {
+ name: "output"
+ description: "A Tensor with the concatenation of values stacked along the\nconcat_dim dimension. This Tensor\'s shape matches the Tensors in\nvalues, except in concat_dim where it has the sum of the sizes."
+ type_attr: "dtype"
+}
+summary: "Concatenate N Tensors along one dimension."
+attr {
+ name: "dtype"
+ type: "type"
+}
+attr {
+ name: "num_values"
+ type: "int"
+ has_minimum: true
+ minimum: 2
+}
+)proto");
+}
+
+TEST_F(OpDefBuilderTest, DocAttr) {
+ ExpectOrdered(b().Attr("i: int").Doc(R"doc(
+Summary
+
+i: How much to operate.
+)doc"),
+ R"proto(
+summary: "Summary"
+attr {
+ name: "i"
+ type: "int"
+ description: "How much to operate."
+}
+)proto");
+}
+
+TEST_F(OpDefBuilderTest, DocCalledTwiceFailure) {
+ ExpectFailure(b().Doc("What's").Doc("up, doc?"),
+ "Extra call to Doc() for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, DocFailureMissingName) {
+ ExpectFailure(
+ b().Input("a: int32").Doc(R"doc(
+Summary
+
+a: Something for a.
+b: b is not defined.
+)doc"),
+ "No matching input/output/attr for name 'b' from Doc() for Op Test");
+
+ ExpectFailure(
+ b().Input("a: int32").Doc(R"doc(
+Summary
+
+b: b is not defined and by itself.
+)doc"),
+ "No matching input/output/attr for name 'b' from Doc() for Op Test");
+}
+
+TEST_F(OpDefBuilderTest, DefaultMinimum) {
+ ExpectSuccess(b().Input("values: num_values * dtype")
+ .Output("output: anything")
+ .Attr("anything: list(type)")
+ .Attr("dtype: type")
+ .Attr("num_values: int"),
+ R"proto(
+input_arg {
+ name: "values"
+ type_attr: "dtype"
+ number_attr: "num_values"
+}
+output_arg {
+ name: "output"
+ type_list_attr: "anything"
+}
+attr {
+ name: "anything"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+}
+attr {
+ name: "dtype"
+ type: "type"
+}
+attr {
+ name: "num_values"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+}
+)proto");
+}
+
+} // namespace
+} // namespace tensorflow
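
The specification strings exercised by these tests are the same ones accepted when
registering an op. A minimal sketch of how they compose, assuming the REGISTER_OP
macro from tensorflow/core/framework/op.h; "ExampleOp" and its attr names are
placeholders, not ops added by this change:

    REGISTER_OP("ExampleOp")
        .Input("values: n * T")
        .Output("out: T")
        .Attr("T: {float, int32}")
        .Attr("n: int >= 1")
        .Attr("alpha: float = 1.25");
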
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
new file mode 100644
index 0000000000..e3aef011de
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -0,0 +1,344 @@
+#include "tensorflow/core/framework/op_def_util.h"
+
+#include <set>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+namespace { // ------ Helper functions ------
+
+bool HasAttrStyleType(const OpDef::ArgDef& arg) {
+ return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
+ !arg.type_list_attr().empty();
+}
+
+Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
+ const AttrValue& allowed_values(attr.allowed_values());
+ for (auto allowed : allowed_values.list().type()) {
+ if (dt == allowed) {
+ return Status::OK();
+ }
+ }
+ string allowed_str;
+ for (int i = 0; i < allowed_values.list().type_size(); ++i) {
+ if (!allowed_str.empty()) {
+ strings::StrAppend(&allowed_str, ", ");
+ }
+ strings::StrAppend(&allowed_str,
+ DataTypeString(allowed_values.list().type(i)));
+ }
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
+ " is not in the list of allowed values: ", allowed_str);
+}
+
+Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
+ const AttrValue& allowed_values(attr.allowed_values());
+ for (auto allowed : allowed_values.list().s()) {
+ if (str == allowed) {
+ return Status::OK();
+ }
+ }
+ string allowed_str;
+ for (const string& allowed : allowed_values.list().s()) {
+ if (!allowed_str.empty()) {
+ strings::StrAppend(&allowed_str, ", ");
+ }
+ strings::StrAppend(&allowed_str, "\"", allowed, "\"");
+ }
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of \"", str,
+ "\" is not in the list of allowed values: ", allowed_str);
+}
+
+} // namespace
+
+// Requires: attr has already been validated.
+Status ValidateAttrValue(const AttrValue& attr_value,
+ const OpDef::AttrDef& attr) {
+ // Is it a valid value?
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
+ " for attr '", attr.name(), "'");
+
+ // Does the value satisfy the minimum constraint in the AttrDef?
+ if (attr.has_minimum()) {
+ if (attr.type() == "int") {
+ if (attr_value.i() < attr.minimum()) {
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of ", attr_value.i(),
+ " must be at least minimum ", attr.minimum());
+ }
+ } else {
+ int length = -1;
+ if (attr.type() == "list(string)") {
+ length = attr_value.list().s_size();
+ } else if (attr.type() == "list(int)") {
+ length = attr_value.list().i_size();
+ } else if (attr.type() == "list(float)") {
+ length = attr_value.list().f_size();
+ } else if (attr.type() == "list(bool)") {
+ length = attr_value.list().b_size();
+ } else if (attr.type() == "list(type)") {
+ length = attr_value.list().type_size();
+ } else if (attr.type() == "list(shape)") {
+ length = attr_value.list().shape_size();
+ } else if (attr.type() == "list(tensor)") {
+ length = attr_value.list().tensor_size();
+ }
+ if (length < attr.minimum()) {
+ return errors::InvalidArgument(
+ "Length for attr '", attr.name(), "' of ", length,
+ " must be at least minimum ", attr.minimum());
+ }
+ }
+ }
+
+ // Does the value satisfy the allowed_value constraint in the AttrDef?
+ if (attr.has_allowed_values()) {
+ if (attr.type() == "type") {
+ TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
+ } else if (attr.type() == "list(type)") {
+ for (int dt : attr_value.list().type()) {
+ TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
+ }
+ } else if (attr.type() == "string") {
+ TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
+ } else if (attr.type() == "list(string)") {
+ for (const string& str : attr_value.list().s()) {
+ TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
+ }
+ } else {
+ return errors::Unimplemented(
+ "Support for allowed_values not implemented for type ", attr.type());
+ }
+ }
+ return Status::OK();
+}
+
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ if (op_def.attr(i).name() == name) {
+ return &op_def.attr(i);
+ }
+ }
+ return nullptr;
+}
+
+OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
+ for (int i = 0; i < op_def->attr_size(); ++i) {
+ if (op_def->attr(i).name() == name) {
+ return op_def->mutable_attr(i);
+ }
+ }
+ return nullptr;
+}
+
+#define VALIDATE(EXPR, ...) \
+ do { \
+ if (!(EXPR)) { \
+ return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \
+ op_def.ShortDebugString()); \
+ } \
+ } while (false)
+
+static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
+ bool output, std::set<string>* names) {
+ const string suffix = strings::StrCat(
+ output ? " for output '" : " for input '", arg.name(), "'");
+ VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ",
+ arg.name());
+ VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
+
+ if (!arg.number_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
+ suffix, " has type ", attr->type(), " != int");
+ VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
+ suffix, " must have minimum");
+ VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
+ suffix, " must have minimum >= 0");
+ VALIDATE(arg.type_list_attr().empty(),
+ "Can't have both number_attr and type_list_attr", suffix);
+ VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
+ (!arg.type_attr().empty() ? 1 : 0) ==
+ 1,
+ "Exactly one of type, type_attr must be set", suffix);
+ } else {
+ const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
+ (!arg.type_attr().empty() ? 1 : 0) +
+ (!arg.type_list_attr().empty() ? 1 : 0);
+ VALIDATE(num_type_fields == 1,
+ "Exactly one of type, type_attr, type_list_attr must be set",
+ suffix);
+ }
+
+ if (!arg.type_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "type", "Attr '", attr->name(),
+ "' used as type_attr", suffix, " has type ", attr->type(),
+ " != type");
+ } else if (!arg.type_list_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
+ "' used as type_list_attr", suffix, " has type ", attr->type(),
+ " != list(type)");
+ } else {
+ // All argument types should be non-reference types at this point.
+ // ArgDef.is_ref is set to true for reference arguments.
+ VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
+ DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
+ }
+
+ return Status::OK();
+}
+
+Status ValidateOpDef(const OpDef& op_def) {
+ VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"),
+ "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+
+ std::set<string> names; // for detecting duplicate names
+ for (const auto& attr : op_def.attr()) {
+ // Validate name
+ VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ",
+ attr.name());
+ DataType dt;
+ VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
+ attr.name(), " that matches a data type");
+
+ // Validate type
+ StringPiece type(attr.type());
+ bool is_list = type.Consume("list(");
+ bool found = false;
+ for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
+ "tensor", "func"}) {
+ if (type.Consume(valid)) {
+ found = true;
+ break;
+ }
+ }
+ VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
+ "'");
+ if (is_list) {
+ VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ",
+ attr.name(), "'s type ", attr.type());
+ }
+ VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
+ attr.name(), "'s type ", attr.type());
+
+ // Validate minimum
+ if (attr.has_minimum()) {
+ VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
+ "' has minimum for unsupported type ", attr.type());
+ if (is_list) {
+ VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
+ "' with list type must have a non-negative minimum, not ",
+ attr.minimum());
+ }
+ } else {
+ VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
+ "' with has_minimum = false but minimum ", attr.minimum(),
+ " not equal to default of 0");
+ }
+
+ // Validate allowed_values
+ if (attr.has_allowed_values()) {
+ const string list_type =
+ is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
+ attr.name(), "' in Op '", op_def.name(), "'");
+ }
+
+ // Validate default_value (after we have validated the rest of the attr,
+ // so we can use ValidateAttrValue()).
+ if (attr.has_default_value()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.default_value(), attr), " in Op '",
+ op_def.name(), "'");
+ }
+ }
+
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
+ }
+
+ for (const auto& arg : op_def.output_arg()) {
+ TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
+ }
+
+ return Status::OK();
+}
+
+#undef VALIDATE
+
+namespace {
+
+string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
+ string ret;
+ for (const OpDef::ArgDef& arg : args) {
+ if (!ret.empty()) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, arg.name(), ":");
+ if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
+ if (!arg.number_attr().empty()) {
+ strings::StrAppend(&ret, arg.number_attr(), "*");
+ }
+ if (arg.type() != DT_INVALID) {
+ strings::StrAppend(&ret, DataTypeString(arg.type()));
+ } else {
+ strings::StrAppend(&ret, arg.type_attr());
+ }
+ if (arg.is_ref()) strings::StrAppend(&ret, ")");
+ }
+ return ret;
+}
+
+} // namespace
+
+string SummarizeOpDef(const OpDef& op_def) {
+ string ret = strings::StrCat("Op<name=", op_def.name());
+ strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
+ " -> ", SummarizeArgs(op_def.output_arg()));
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
+ op_def.attr(i).type());
+ if (op_def.attr(i).has_default_value()) {
+ strings::StrAppend(&ret, ",default=",
+ SummarizeAttrValue(op_def.attr(i).default_value()));
+ }
+ if (op_def.attr(i).has_minimum()) {
+ strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
+ }
+ if (op_def.attr(i).has_allowed_values()) {
+ strings::StrAppend(&ret, ",allowed=",
+ SummarizeAttrValue(op_def.attr(i).allowed_values()));
+ }
+ }
+ if (op_def.is_commutative()) {
+ strings::StrAppend(&ret, "; is_commutative=true");
+ }
+ if (op_def.is_aggregate()) {
+ strings::StrAppend(&ret, "; is_aggregate=true");
+ }
+ if (op_def.is_stateful()) {
+ strings::StrAppend(&ret, "; is_stateful=true");
+ }
+ if (op_def.allows_uninitialized_input()) {
+ strings::StrAppend(&ret, "; allows_uninitialized_input=true");
+ }
+ strings::StrAppend(&ret, ">");
+ return ret;
+}
+
+} // namespace tensorflow
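
For orientation, SummarizeOpDef above flattens an OpDef into one line. For a
hypothetical op Foo with inputs "a: int32" and "b: n * T", output "out: T", and
attrs "T: type" and "n: int >= 1", the summary comes out roughly as:

    Op<name=Foo; signature=a:int32, b:n*T -> out:T; attr=T:type; attr=n:int,min=1>
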
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
new file mode 100644
index 0000000000..a9fecf3fa0
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util.h
@@ -0,0 +1,32 @@
+// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't
+// need to be as publicly accessible as other files in framework/.
+
+#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Performs a consistency check across the fields of the op_def.
+Status ValidateOpDef(const OpDef& op_def);
+
+// Validates that attr_value satisfies the type and constraints from attr.
+// REQUIRES: attr has already been validated.
+Status ValidateAttrValue(const AttrValue& attr_value,
+ const OpDef::AttrDef& attr);
+
+// The following functions search through op_def for an attr with the indicated name.
+// Returns nullptr if no such attr is found.
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def);
+OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
+
+// Produces a human-readable version of an op_def that is more concise
+// than a text-format proto. Excludes descriptions.
+string SummarizeOpDef(const OpDef& op_def);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
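
A minimal sketch of the intended call pattern for the declarations above, assuming
OpDefBuilder from op_def_builder.h; the wrapper function, op name, and attr spec
are placeholders for illustration only:

    #include "tensorflow/core/framework/op_def_builder.h"
    #include "tensorflow/core/framework/op_def_util.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {

    void OpDefUtilExample() {
      OpDef op_def;
      Status s = OpDefBuilder("ExampleOp")
                     .Input("x: T")
                     .Attr("T: {float, int32}")
                     .Finalize(&op_def);
      CHECK(s.ok()) << s;                   // builder parsed the spec strings
      CHECK(ValidateOpDef(op_def).ok());    // consistency check declared above
      const OpDef::AttrDef* t = FindAttr("T", op_def);
      CHECK(t != nullptr);                  // "T" was declared via Attr()
      LOG(INFO) << SummarizeOpDef(op_def);  // concise one-line description
    }

    }  // namespace tensorflow
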
diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc
new file mode 100644
index 0000000000..515e8bb288
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util_test.cc
@@ -0,0 +1,330 @@
+#include "tensorflow/core/framework/op_def_util.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+OpDef FromText(const string& text) {
+ OpDef op_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &op_def));
+ return op_def;
+}
+
+class ValidateOpDefTest : public ::testing::Test {
+ protected:
+ Status TestProto(const string& text) {
+ return ValidateOpDef(FromText(text));
+ }
+
+ Status TestBuilder(const OpDefBuilder& builder) {
+ OpDef op_def;
+ Status status = builder.Finalize(&op_def);
+ EXPECT_OK(status);
+ if (!status.ok()) {
+ return status;
+ } else {
+ return ValidateOpDef(op_def);
+ }
+ }
+
+ void ExpectFailure(const Status& status, const string& message) {
+ EXPECT_FALSE(status.ok()) << "Did not see error with: " << message;
+ if (!status.ok()) {
+ LOG(INFO) << "message: " << status;
+ EXPECT_TRUE(StringPiece(status.ToString()).contains(message))
+ << "Actual: " << status << "\nExpected to contain: " << message;
+ }
+ }
+};
+
+TEST_F(ValidateOpDefTest, OpDefValid) {
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Input("a: int32")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Output("a: bool")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("t: type").Input("a: t")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int = 3")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: int >= -5 = 3")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("X").Attr("a: numbertype")));
+ EXPECT_OK(TestBuilder(OpDefBuilder("Uppercase")));
+}
+
+TEST_F(ValidateOpDefTest, InvalidName) {
+ ExpectFailure(TestBuilder(OpDefBuilder("lower").Attr("a: int")),
+ "Invalid name");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadSuffix 7%")), "Invalid name");
+}
+
+TEST_F(ValidateOpDefTest, DuplicateName) {
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Input("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("DupeName").Input("a: int32").Output("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(
+ OpDefBuilder("DupeName").Output("a: int32").Output("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Input("a: int32").Attr("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Output("a: int32").Attr("a: float")),
+ "Duplicate name: a");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("DupeName").Attr("a: int").Attr("a: float")),
+ "Duplicate name: a");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrName) {
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("int32: int")),
+ "Attr can't have name int32 that matches a data type");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude").Attr("float: string")),
+ "Attr can't have name float that matches a data type");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrType) {
+ ExpectFailure(
+ TestProto("name: 'BadAttrType' attr { name: 'a' type: 'illegal' }"),
+ "Unrecognized type");
+ ExpectFailure(
+ TestProto("name: 'BadAttrType' attr { name: 'a' type: 'list(illegal)' }"),
+ "Unrecognized type");
+ ExpectFailure(
+ TestProto("name: 'BadAttrType' attr { name: 'a' type: 'int extra' }"),
+ "Extra ' extra' at the end");
+ ExpectFailure(
+ TestProto(
+ "name: 'BadAttrType' attr { name: 'a' type: 'list(int extra)' }"),
+ "'list(' is missing ')' in attr");
+ ExpectFailure(
+ TestProto(
+ "name: 'BadAttrType' attr { name: 'a' type: 'list(int) extra' }"),
+ "Extra ' extra' at the end");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrDefault) {
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'int' default_value { s: 'x' } }"),
+ "AttrValue had value with type string when int expected\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'int' default_value { f: 0.5 } }"),
+ "AttrValue had value with type float when int expected\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'int' "
+ "default_value { i: 5 list { i: [2] } } }"),
+ "AttrValue had value with type list(int) when int expected\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'list(int)' default_value { f: 0.5 } }"),
+ "AttrValue had value with type float when list(int) expected\n\t "
+ "for attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrDef' attr { name: 'a' type: 'list(int)' "
+ "default_value { list { i: [5] f: [0.5] } } }"),
+ "AttrValue had value with type list(float) when list(int) "
+ "expected\n\t for attr 'a'\n\t in Op 'BadAttrDef'");
+
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'type' default_value { } }"),
+ "AttrValue missing value with expected type type\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'shape' default_value { } }"),
+ "AttrValue missing value with expected type shape\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'tensor' default_value { } }"),
+ "AttrValue missing value with expected type tensor\n\t for "
+ "attr 'a'\n\t in Op 'BadAttrDef'");
+
+ // default_value {} is indistinguishable from default_value{ list{} } (one
+ // with an empty list) in proto3 semantics.
+ EXPECT_OK(
+ TestProto("name: 'GoodAttrDef' attr { name: 'a' "
+ "type: 'list(int)' default_value { } }"));
+
+ // Empty lists are allowed:
+ EXPECT_OK(
+ TestProto("name: 'GoodAttrDef' attr { name: 'a' "
+ "type: 'list(int)' default_value { list { } } }"));
+ // Builder should make the same proto:
+ EXPECT_OK(TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(int) = []")));
+
+ // Unless there is a minimum length specified:
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' "
+ "type: 'list(int)' has_minimum: true minimum: 2 "
+ "default_value { list { } } }"),
+ "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op "
+ "'BadAttrDef'");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("GoodAttrDef").Attr("a: list(bool) >=2 = []")),
+ "Length for attr 'a' of 0 must be at least minimum 2\n\t in Op "
+ "'GoodAttrDef'");
+ ExpectFailure(TestProto("name: 'BadAttrDef' attr { name: 'a' type: "
+ "'list(string)' has_minimum: true minimum: 2 "
+ "default_value { list { s: ['foo'] } } }"),
+ "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
+ "'BadAttrDef'");
+ ExpectFailure(TestBuilder(OpDefBuilder("GoodAttrDef")
+ .Attr("a: list(type) >=2 = [DT_STRING]")),
+ "Length for attr 'a' of 1 must be at least minimum 2\n\t in Op "
+ "'GoodAttrDef'");
+}
+
+TEST_F(ValidateOpDefTest, NoRefTypes) {
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef").Input("i: float_ref")),
+ "Illegal use of ref type 'float_ref'. "
+ "Use 'Ref(type)' instead for input 'i'");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("BadAttrDef").Attr("T: type = DT_INT32_REF")),
+ "AttrValue must not have reference type value of int32_ref");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrDef")
+ .Attr("T: list(type) = [DT_STRING_REF]")),
+ "AttrValue must not have reference type value of string_ref");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrMin) {
+ ExpectFailure(TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'string' "
+ "has_minimum: true minimum: 0 }"),
+ "minimum for unsupported type string");
+ ExpectFailure(
+ TestProto("name: 'BadAttrMin' attr { name: 'a' type: 'int' default_value "
+ "{ i: 2 } has_minimum: true minimum: 7 }"),
+ "Value for attr 'a' of 2 must be at least minimum 7\n\t in Op "
+ "'BadAttrMin'");
+ ExpectFailure(
+ TestProto("name: 'BadAttrMin' attr { name: 'a' "
+ "type: 'list(string)' has_minimum: true minimum: -5 }"),
+ "list type must have a non-negative minimum, not -5");
+ EXPECT_OK(
+ TestProto("name: 'GoodAttrMin' attr { name: 'a' type: 'list(string)' "
+ "has_minimum: true minimum: 1 }"));
+ ExpectFailure(TestProto("name: 'NoHasMin' attr { name: 'a' "
+ "type: 'list(string)' minimum: 3 }"),
+ "Attr 'a' with has_minimum = false but minimum 3 not equal to "
+ "default of 0");
+}
+
+TEST_F(ValidateOpDefTest, BadAttrAllowed) {
+ // Is in list of allowed types.
+ EXPECT_OK(TestBuilder(
+ OpDefBuilder("GoodAttrtude").Attr("x: numbertype = DT_INT32")));
+ // Not in list of allowed types.
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: numbertype = DT_STRING")),
+ "attr 'x' of string is not in the list of allowed values");
+ ExpectFailure(
+ TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: list(realnumbertype) = [DT_COMPLEX64]")),
+ "attr 'x' of complex64 is not in the list of allowed values");
+ // Is in list of allowed strings.
+ EXPECT_OK(TestBuilder(
+ OpDefBuilder("GoodAttrtude").Attr("x: {'foo', 'bar'} = 'bar'")));
+ // Not in list of allowed strings.
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: {'foo', 'bar'} = 'baz'")),
+ "attr 'x' of \"baz\" is not in the list of allowed values");
+ ExpectFailure(TestBuilder(OpDefBuilder("BadAttrtude")
+ .Attr("x: list({'foo', 'bar'}) = ['baz']")),
+ "attr 'x' of \"baz\" is not in the list of allowed values");
+ ExpectFailure(TestProto(
+ "name: 'BadAttrtude' attr { name: 'a' "
+ "type: 'string' allowed_values { s: 'not list' } }"),
+ "with type string when list(string) expected");
+ ExpectFailure(
+ TestProto("name: 'BadAttrtude' attr { name: 'a' "
+ "type: 'string' allowed_values { list { i: [6] } } }"),
+ "with type list(int) when list(string) expected");
+}
+
+TEST_F(ValidateOpDefTest, BadArgType) {
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type: DT_INT32 } input_arg { name: 'b' }"),
+ "Missing type for input 'b'");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type: DT_INT32 } output_arg { name: 'b' }"),
+ "Missing type for output 'b'");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' type: "
+ "DT_INT32 type_attr: 'x' } attr { name: 'x' type: 'type' }"),
+ "Exactly one of type, type_attr, type_list_attr must be set for input "
+ "'a'");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_attr: 'x' } attr { name: 'x' type: 'int' }"),
+ "Attr 'x' used as type_attr for input 'a' has type int");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_attr: 'x' } attr { name: 'x' type: 'list(type)' }"),
+ "Attr 'x' used as type_attr for input 'a' has type list(type)");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_list_attr: 'x' } attr { name: 'x' type: 'int' }"),
+ "Attr 'x' used as type_list_attr for input 'a' has type int");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_list_attr: 'x' } attr { name: 'x' type: 'type' }"),
+ "Attr 'x' used as type_list_attr for input 'a' has type type");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' "
+ "type_attr: 'x' }"),
+ "No attr with name 'x' for input 'a'");
+ ExpectFailure(
+ TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: 'n' "
+ "type_attr: 'x' } attr { name: 'x' type: 'list(type)' } "
+ "attr { name: 'n' type: 'int' has_minimum: true minimum: 1 }"),
+ "Attr 'x' used as type_attr for input 'a' has type list(type)");
+ // But list(type) is fine as the type of an arg without a number_attr:
+ EXPECT_OK(TestProto(
+ "name: 'Arg' input_arg { name: 'a' type_list_attr: 'x' } "
+ "attr { name: 'x' type: 'list(type)' } attr { name: 'n' type: 'int' "
+ "has_minimum: true minimum: 1 }"));
+
+ // number_attr
+ EXPECT_OK(TestProto(
+ "name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: 'n' } "
+ "attr { name: 'n' type: 'int' has_minimum: true minimum: 0 }"));
+
+ ExpectFailure(TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 "
+ "number_attr: 'n' }"),
+ "No attr with name 'n'");
+ ExpectFailure(
+ TestProto(
+ "name: 'Arg' input_arg { name: 'a' type: "
+ "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'string' }"),
+ "Attr 'n' used as length for input 'a' has type string");
+ ExpectFailure(
+ TestProto("name: 'Arg' input_arg { name: 'a' type: "
+ "DT_INT32 number_attr: 'n' } attr { name: 'n' type: 'int' }"),
+ "Attr 'n' used as length for input 'a' must have minimum;");
+ ExpectFailure(
+ TestProto("name: 'Arg' input_arg { name: 'a' type: DT_INT32 number_attr: "
+ "'n' } attr { name: 'n' type: 'int' has_minimum: true minimum: "
+ "-5 }"),
+ "Attr 'n' used as length for input 'a' must have minimum >= 0;");
+ ExpectFailure(
+ TestProto("name: 'Arg' input_arg { name: 'a' number_attr: 'n' } attr { "
+ "name: 'n' type: 'int' has_minimum: true minimum: 2 }"),
+ "Missing type for input 'a'; in OpDef:");
+ ExpectFailure(TestProto("name: 'BadArg' input_arg { name: 'a' number_attr: "
+ "'n' type_list_attr: 'x' } attr { name: 'n' type: "
+ "'int' has_minimum: true minimum: 1 } attr { name: "
+ "'x' type: 'list(type)' }"),
+ "Can't have both number_attr and type_list_attr for input 'a'");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
new file mode 100644
index 0000000000..04f4b7cacd
--- /dev/null
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -0,0 +1,55 @@
+#include "tensorflow/core/framework/op_gen_lib.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+string WordWrap(StringPiece prefix, StringPiece str, int width) {
+ const string indent_next_line = "\n" + Spaces(prefix.size());
+ width -= prefix.size();
+ string result;
+ strings::StrAppend(&result, prefix);
+
+ while (!str.empty()) {
+ if (static_cast<int>(str.size()) <= width) {
+ // Remaining text fits on one line.
+ strings::StrAppend(&result, str);
+ break;
+ }
+ auto space = str.rfind(' ', width);
+ if (space == StringPiece::npos) {
+ // No space within the width; rather than break mid-word, make a
+ // too-long line and break at the next space.
+ space = str.find(' ');
+ if (space == StringPiece::npos) {
+ strings::StrAppend(&result, str);
+ break;
+ }
+ }
+ // Break at the space character at position <space>.
+ StringPiece to_append = str.substr(0, space);
+ str.remove_prefix(space + 1);
+ // Remove spaces at break.
+ while (to_append.ends_with(" ")) {
+ to_append.remove_suffix(1);
+ }
+ while (str.Consume(" ")) {
+ }
+
+ // Go on to the next line.
+ strings::StrAppend(&result, to_append);
+ if (!str.empty()) strings::StrAppend(&result, indent_next_line);
+ }
+
+ return result;
+}
+
+bool ConsumeEquals(StringPiece* description) {
+ if (description->Consume("=")) {
+ while (description->Consume(" ")) { // Also remove spaces after "=".
+ }
+ return true;
+ }
+ return false;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h
new file mode 100644
index 0000000000..9890f1bcec
--- /dev/null
+++ b/tensorflow/core/framework/op_gen_lib.h
@@ -0,0 +1,24 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+
+#include <string>
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+inline string Spaces(int n) { return string(n, ' '); }
+
+// Wrap prefix + str to be at most width characters, indenting every line
+// after the first by prefix.size() spaces. Intended use case is something
+// like prefix = " Foo(" and str is a list of arguments (terminated by a ")").
+// TODO(josh11b): Option to wrap on ", " instead of " " when possible.
+string WordWrap(StringPiece prefix, StringPiece str, int width);
+
+// Looks for an "=" at the beginning of *description. If found, strips it off
+// (and any following spaces) from *description and returns true. Otherwise
+// returns false.
+bool ConsumeEquals(StringPiece* description);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
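
A hand-worked example of the two helpers above (illustrative sketch, not an
existing test); the wrapper function is a placeholder:

    #include "tensorflow/core/framework/op_gen_lib.h"
    #include "tensorflow/core/lib/core/stringpiece.h"

    namespace tensorflow {

    void OpGenLibExample() {
      // Continuation lines are indented by prefix.size() spaces.
      string wrapped = WordWrap("Args: ", "alpha beta gamma delta", 16);
      // wrapped == "Args: alpha beta\n"
      //            "      gamma\n"
      //            "      delta"

      StringPiece d("= the description");
      bool had_equals = ConsumeEquals(&d);  // true; d now holds "the description"
      (void)had_equals;
    }

    }  // namespace tensorflow
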
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
new file mode 100644
index 0000000000..eb83d393f0
--- /dev/null
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -0,0 +1,749 @@
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs,
+ const DataTypeSlice inputs,
+ const DataTypeSlice outputs) {
+ bool signature_mismatch = false;
+
+ if (inputs.size() != expected_inputs.size()) signature_mismatch = true;
+ for (size_t i = 0; !signature_mismatch && i < inputs.size(); ++i) {
+ if (!TypesCompatible(expected_inputs[i], inputs[i])) {
+ signature_mismatch = true;
+ }
+ }
+
+ if (outputs.size() != expected_outputs.size()) signature_mismatch = true;
+ for (size_t i = 0; !signature_mismatch && i < outputs.size(); ++i) {
+ if (!TypesCompatible(expected_outputs[i], outputs[i])) {
+ signature_mismatch = true;
+ }
+ }
+
+ if (signature_mismatch) {
+ return errors::InvalidArgument("Signature mismatch, have: ",
+ DataTypeSliceString(inputs), "->",
+ DataTypeSliceString(outputs), " expected: ",
+ DataTypeSliceString(expected_inputs), "->",
+ DataTypeSliceString(expected_outputs));
+ }
+ return Status::OK();
+}
+
+// Check HostMemory backward compatibility.
+bool CheckHostMemoryCompatibility(const DeviceType device_type,
+ const OpKernel* kernel) {
+ if (device_type == DEVICE_GPU) {
+ for (int i = 0; i < kernel->num_inputs(); ++i) {
+ if (kernel->input_type(i) == DT_INT32 &&
+ kernel->input_memory_types()[i] != HOST_MEMORY) {
+ return false;
+ }
+ }
+ for (int i = 0; i < kernel->num_outputs(); ++i) {
+ if (kernel->output_type(i) == DT_INT32 &&
+ kernel->output_memory_types()[i] != HOST_MEMORY) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+// OpKernel ------------------------------------------------------------------
+
+OpKernel::OpKernel(OpKernelConstruction* context)
+ : def_(context->def()),
+ input_types_(context->input_types().begin(),
+ context->input_types().end()),
+ output_types_(context->output_types().begin(),
+ context->output_types().end()),
+ input_name_map_(context->num_inputs()),
+ output_name_map_(context->num_outputs()) {
+ OP_REQUIRES_OK(context,
+ NameRangesForNode(def_, context->op_def(), &input_name_map_,
+ &output_name_map_));
+
+ // By default, the input and output memory types are always in device memory,
+ // but can be overridden by individual implementations of OpKernels in their
+ // constructor.
+ input_memory_types_ = MemoryTypeVector(input_types_.size(), DEVICE_MEMORY);
+ output_memory_types_ = MemoryTypeVector(output_types_.size(), DEVICE_MEMORY);
+ // TODO(yuanbyu): For now we assume the memory types of function
+ // inputs/outputs to be DEVICE_MEMORY.
+ auto lib = context->function_library();
+ if (lib == nullptr || !lib->IsDefined(def_.op())) {
+ OP_REQUIRES_OK(context, MemoryTypesForNode(
+ context->device_type(), def_, context->op_def(),
+ input_name_map_, output_name_map_,
+ &input_memory_types_, &output_memory_types_));
+ // Log all the uses of int32 on GPU.
+ // TODO(yuanbyu): Remove once everyone transitions to HostMemory.
+ if (VLOG_IS_ON(2)) {
+ if (!CheckHostMemoryCompatibility(context->device_type(), this)) {
+ VLOG(2) << "Using int32 on GPU at node: " << SummarizeNodeDef(def());
+ }
+ }
+ }
+}
+
+Status OpKernel::InputRange(const string& input_name, int* start,
+ int* stop) const {
+ const auto result = input_name_map_.find(input_name);
+ if (result == input_name_map_.end()) {
+ return errors::InvalidArgument("Unknown input name: ", input_name);
+ } else {
+ *start = result->second.first;
+ *stop = result->second.second;
+ return Status::OK();
+ }
+}
+
+Status OpKernel::OutputRange(const string& output_name, int* start,
+ int* stop) const {
+ const auto result = output_name_map_.find(output_name);
+ if (result == output_name_map_.end()) {
+ return errors::InvalidArgument("Unknown output name: ", output_name);
+ } else {
+ *start = result->second.first;
+ *stop = result->second.second;
+ return Status::OK();
+ }
+}
+
+void AsyncOpKernel::Compute(OpKernelContext* context) {
+ Notification n;
+ ComputeAsync(context, [&n]() { n.Notify(); });
+ n.WaitForNotification();
+}
+
+// PersistentTensor ----------------------------------------------------------
+
+Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
+ // the caller has to have a valid context
+ CHECK(context);
+ return &tensor_;
+}
+
+Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
+ context->NotifyUseOfPersistentTensor(tensor_);
+ return &tensor_;
+}
+
+// OpKernelConstruction ------------------------------------------------------
+
+Status OpKernelConstruction::MatchSignature(
+ const DataTypeSlice expected_inputs, const DataTypeSlice expected_outputs) {
+ return MatchSignatureHelper(expected_inputs, expected_outputs, input_types_,
+ output_types_);
+}
+
+Status OpKernelConstruction::allocate_temp(DataType type,
+ const TensorShape& shape,
+ Tensor* out_temp) {
+ Tensor new_temp(allocator_, type, shape);
+
+ if (!new_temp.IsInitialized() && shape.num_elements() > 0) {
+ return errors::ResourceExhausted(
+ "OOM when allocating temporary tensor with shape", shape.DebugString());
+ }
+ *out_temp = new_temp;
+ return Status::OK();
+}
+
+Status OpKernelConstruction::allocate_persistent(
+ DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
+ Tensor** out_tensor) {
+ // for now just do the same thing as allocate_temp
+ // TODO(misard) add specific memory tracking for persistent tensors
+ Tensor persistent;
+ Status s = allocate_temp(type, shape, &persistent);
+ if (!s.ok()) {
+ return s;
+ }
+ *out_persistent = PersistentTensor(persistent);
+ Tensor* allocated = out_persistent->AccessTensor(this);
+ if (out_tensor) {
+ *out_tensor = allocated;
+ }
+ return s;
+}
+
+// OpKernelContext -----------------------------------------------------------
+
+OpKernelContext::OpKernelContext(const Params& params)
+ : params_(params),
+ outputs_(params.op_kernel->output_types().size()),
+ output_allocation_types_(params.op_kernel->output_types().size()) {
+ Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
+ eigen_gpu_device_ = params_.device->MakeGpuDevice(params_.op_device_context,
+ eigen_gpu_allocator);
+}
+
+OpKernelContext::~OpKernelContext() {
+ for (TensorValue& value : outputs_) {
+ if (!value.is_ref()) {
+ delete value.tensor;
+ }
+ }
+ for (Tensor* t : temp_tensors_) delete t;
+ delete eigen_gpu_device_;
+}
+
+Status OpKernelContext::input(const string& name, const Tensor** tensor) const {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was "
+ "expected");
+ }
+ if ((*params_.inputs)[start].is_ref()) {
+ return errors::InvalidArgument("OpKernel used ref input name '", name,
+ "' when immutable input was expected");
+ }
+ *tensor = (*params_.inputs)[start].tensor;
+ return Status::OK();
+}
+
+Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was expected");
+ }
+ *out_mutex = input_ref_mutex(start);
+ return Status::OK();
+}
+
+Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
+ bool lock_held) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was expected");
+ }
+ if (!(*params_.inputs)[start].is_ref()) {
+ return errors::InvalidArgument("OpKernel used immutable input name '", name,
+ "' when ref input was expected");
+ }
+ // return a copy of the Ref acquired while holding the mutex
+ if (lock_held) {
+ *tensor = *(*params_.inputs)[start].tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(start));
+ *tensor = *(*params_.inputs)[start].tensor;
+ }
+ return Status::OK();
+}
+
+Status OpKernelContext::replace_ref_input(const string& name,
+ const Tensor& tensor,
+ bool lock_held) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued input name '",
+ name,
+ "' when single-valued input was expected");
+ }
+ if (!(*params_.inputs)[start].is_ref()) {
+ return errors::InvalidArgument("OpKernel used immutable input name '", name,
+ "' when ref input was expected");
+ }
+ replace_ref_input(start, tensor, lock_held);
+ return Status::OK();
+}
+
+Status OpKernelContext::input_list(const string& name,
+ OpInputList* list) const {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ *list = OpInputList(this, start, stop);
+ return Status::OK();
+}
+
+Status OpKernelContext::mutable_input_list(const string& name,
+ OpMutableInputList* list) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->InputRange(name, &start, &stop));
+ *list = OpMutableInputList(this, start, stop);
+ return Status::OK();
+}
+
+Status OpKernelContext::output_list(const string& name, OpOutputList* list) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ *list = OpOutputList(this, start, stop);
+ return Status::OK();
+}
+
+Status OpKernelContext::allocate_output(const string& name,
+ const TensorShape& shape,
+ Tensor** tensor) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ return allocate_output(start, shape, tensor);
+}
+
+Status OpKernelContext::allocate_output(const string& name,
+ const TensorShape& shape,
+ Tensor** tensor,
+ AllocatorAttributes attr) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ return allocate_output(start, shape, tensor, attr);
+}
+
+Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ set_output(start, tensor);
+ return Status::OK();
+}
+
+Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
+ Tensor* tensor_for_ref) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ set_output_ref(start, mu, tensor_for_ref);
+ return Status::OK();
+}
+
+Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ *tensor = mutable_output(start);
+ return Status::OK();
+}
+
+Status OpKernelContext::release_output(const string& name, TensorValue* value) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(params_.op_kernel->OutputRange(name, &start, &stop));
+ if (stop != start + 1) {
+ return errors::InvalidArgument("OpKernel used list-valued output name '",
+ name,
+ "' when single-valued output was "
+ "expected");
+ }
+ *value = release_output(start);
+ return Status::OK();
+}
+
+bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
+ const auto& inputs = *params_.inputs;
+ for (size_t i = 1; i < inputs.size(); ++i) {
+ if (!inputs[0]->IsSameSize(*(inputs[i].tensor))) {
+ SetStatus(errors::InvalidArgument(
+ "Inputs to operation ", op->name(), " of type ", op->type_string(),
+ " must have the same size and shape. Input 0: ",
+ inputs[0]->shape().DebugString(), " != input ", i, ": ",
+ inputs[i]->shape().DebugString()));
+ return false;
+ }
+ }
+ return true;
+}
+
+Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs) {
+ DataTypeVector inputs;
+ for (const TensorValue& t : *params_.inputs) {
+ inputs.push_back(t.is_ref() ? MakeRefType(t->dtype()) : t->dtype());
+ }
+ DataTypeVector outputs = params_.op_kernel->output_types();
+ return MatchSignatureHelper(expected_inputs, expected_outputs, inputs,
+ outputs);
+}
+
+// OpKernel registration ------------------------------------------------------
+
+struct KernelRegistration {
+ KernelRegistration(const KernelDef& d,
+ kernel_factory::OpKernelRegistrar::Factory f)
+ : def(d), factory(f) {}
+ const KernelDef def;
+ const kernel_factory::OpKernelRegistrar::Factory factory;
+};
+
+// This maps from 'op_type' + DeviceType to the set of KernelDefs and
+// factory functions for instantiating the OpKernel that matches the
+// KernelDef.
+typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
+
+static KernelRegistry* GlobalKernelRegistry() {
+ static KernelRegistry* global_kernel_registry = new KernelRegistry;
+ return global_kernel_registry;
+}
+
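+// Registry keys look like "<op>:<device type>:<label>"; for example,
+// Key("MatMul", DeviceType(DEVICE_GPU), "") is "MatMul:GPU:". The label is
+// empty unless the NodeDef carries a "_kernel" attr (see
+// FindKernelRegistration below).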
+static string Key(const string& op_type, DeviceType device_type,
+ const string& label) {
+ return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
+ label);
+}
+
+namespace kernel_factory {
+
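+// Not usually constructed directly: the REGISTER_KERNEL_BUILDER macro in
+// op_kernel.h creates one static OpKernelRegistrar per registration, e.g.
+//   REGISTER_KERNEL_BUILDER(Name("MyOp").Device(DEVICE_CPU), MyOpKernel);
+// where MyOp/MyOpKernel stand in for a real op name and OpKernel subclass.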
+OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def,
+ Factory factory) {
+ const string key =
+ Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
+ kernel_def->label());
+ GlobalKernelRegistry()->insert(
+ std::make_pair(key, KernelRegistration(*kernel_def, factory)));
+ delete kernel_def;
+}
+
+} // namespace kernel_factory
+
+namespace {
+
+// Helper for AttrsMatch().
+bool InTypeList(DataType dt, const AttrValue& type_list) {
+ for (int in_list : type_list.list().type()) {
+ if (dt == in_list) return true;
+ }
+ return false;
+}
+
+// Returns whether the attrs in the NodeDef satisfy the constraints in
+// the kernel_def. Returns an error if attrs in kernel_def are not
+// found, or have a mismatching type.
+Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def,
+ bool* match) {
+ *match = false;
+ AttrSlice attrs(node_def);
+ for (const auto& constraint : kernel_def.constraint()) {
+ if (constraint.allowed_values().list().type_size() == 0) {
+ return errors::Unimplemented(
+ "KernelDef '", kernel_def.ShortDebugString(),
+ " has constraint on attr '", constraint.name(),
+ "' with unsupported type: ",
+ SummarizeAttrValue(constraint.allowed_values()));
+ }
+
+ const AttrValue* found = attrs.Find(constraint.name());
+ if (found) {
+ if (found->type() != DT_INVALID) {
+ if (!InTypeList(found->type(), constraint.allowed_values())) {
+ return Status::OK();
+ }
+ } else {
+ if (!AttrValueHasType(*found, "list(type)").ok()) {
+ return errors::InvalidArgument(
+ "KernelDef '", kernel_def.ShortDebugString(),
+ "' has constraint on attr '", constraint.name(),
+ "' that has value '", SummarizeAttrValue(*found),
+ "' that does not have type 'type' or 'list(type)' in NodeDef '",
+ SummarizeNodeDef(node_def), "'");
+ }
+
+ for (int t : found->list().type()) {
+ if (!InTypeList(static_cast<DataType>(t),
+ constraint.allowed_values())) {
+ return Status::OK();
+ }
+ }
+ }
+ } else {
+ return errors::InvalidArgument(
+ "OpKernel '", kernel_def.op(), "' has constraint on attr '",
+ constraint.name(), "' not in NodeDef '", SummarizeNodeDef(node_def),
+ "', KernelDef: '", kernel_def.ShortDebugString(), "'");
+ }
+ }
+ *match = true;
+ return Status::OK();
+}
+
+Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
+ const KernelRegistration** reg) {
+ *reg = nullptr;
+ string label; // Label defaults to empty if not found in NodeDef.
+ GetNodeAttr(node_def, "_kernel", &label);
+ const string key = Key(node_def.op(), device_type, label);
+ auto regs = GlobalKernelRegistry()->equal_range(key);
+ for (auto iter = regs.first; iter != regs.second; ++iter) {
+ // If there is a kernel registered for the op and device_type,
+ // check that the attrs match.
+ bool match;
+ TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
+ if (match) {
+ if (*reg != nullptr) {
+ return errors::InvalidArgument(
+ "Multiple OpKernel registrations match NodeDef '",
+ SummarizeNodeDef(node_def), "': '", (*reg)->def.ShortDebugString(),
+ "' and '", iter->second.def.ShortDebugString(), "'");
+ }
+ *reg = &iter->second;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status SupportedDeviceTypesForNode(
+ const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
+ DeviceTypeVector* device_types) {
+ // TODO(zhifengc): Change the callers (SimplePlacer and
+ // DynamicPlacer) to consider the possibility that 'def' is a call to
+ // a user-defined function, and only call
+ // SupportedDeviceTypesForNode for primitive ops.
+ Status s;
+ const OpDef* op_def = OpRegistry::Global()->LookUp(def.op(), &s);
+ if (op_def) {
+ for (const DeviceType& device_type : prioritized_types) {
+ const KernelRegistration* reg = nullptr;
+ TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, def, &reg));
+ if (reg != nullptr) device_types->push_back(device_type);
+ }
+ } else {
+ // Assumes that all device types support this node.
+ for (const DeviceType& device_type : prioritized_types) {
+ device_types->push_back(device_type);
+ }
+ }
+ return Status::OK();
+}
+
+std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
+ DeviceBase* device,
+ Allocator* allocator,
+ const NodeDef& node_def,
+ Status* status) {
+ OpKernel* kernel = nullptr;
+ *status = CreateOpKernel(device_type, device, allocator, nullptr, node_def,
+ &kernel);
+ return std::unique_ptr<OpKernel>(kernel);
+}
+
+Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
+ Allocator* allocator, FunctionLibraryRuntime* flib,
+ const NodeDef& node_def, OpKernel** kernel) {
+ VLOG(1) << "Instantiating kernel for node: " << SummarizeNodeDef(node_def);
+
+ // Look up the Op registered for this op name.
+ Status s;
+ const OpDef* op_def = OpRegistry::Global()->LookUp(node_def.op(), &s);
+ if (op_def == nullptr) return s;
+
+ // Validate node_def against OpDef.
+ s = ValidateNodeDef(node_def, *op_def);
+ if (!s.ok()) return s;
+
+ // Look up kernel registration.
+ const KernelRegistration* registration;
+ s = FindKernelRegistration(device_type, node_def, &registration);
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " when instantiating ", node_def.op());
+ return s;
+ }
+ if (registration == nullptr) {
+ s.Update(errors::NotFound("No registered '", node_def.op(),
+ "' OpKernel for ", DeviceTypeString(device_type),
+ " devices compatible with node ",
+ SummarizeNodeDef(node_def)));
+ return s;
+ }
+
+ // Get signature from the OpDef & NodeDef
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ s.Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
+ if (!s.ok()) {
+ errors::AppendToMessage(&s, " for node: ", SummarizeNodeDef(node_def));
+ return s;
+ }
+
+ // Everything needed for OpKernel construction.
+ OpKernelConstruction context(device_type, device, allocator, &node_def,
+ op_def, flib, inputs, outputs, &s);
+ *kernel = (*registration->factory)(&context);
+ if (!s.ok()) {
+ delete *kernel;
+ *kernel = nullptr;
+ }
+ return s;
+}
+
+namespace { // Helper for MemoryTypesForNode.
+// Fills memory_types for either input or output, setting everything
+// to DEVICE_MEMORY except those args in host_memory_args. Removes
+// elements of host_memory_args that were used.
+void MemoryTypesHelper(const NameRangeMap& name_map,
+ std::vector<string>* host_memory_args,
+ MemoryTypeVector* memory_types) {
+ // Set total to the largest endpoint of anything in the name_map.
+ int total = 0;
+ for (const auto& item : name_map) {
+ total = std::max(total, item.second.second);
+ }
+
+ // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
+ memory_types->clear();
+ memory_types->resize(total, DEVICE_MEMORY);
+
+ // Update args that have been marked as in "HOST_MEMORY".
+ size_t keep = 0;
+ for (size_t i = 0; i < host_memory_args->size(); ++i) {
+ auto iter = name_map.find((*host_memory_args)[i]);
+ if (iter != name_map.end()) {
+ for (int j = iter->second.first; j < iter->second.second; ++j) {
+ (*memory_types)[j] = HOST_MEMORY;
+ }
+ } else {
+ // (*host_memory_args)[i] not found, save it for the next pass.
+ if (i > keep) (*host_memory_args)[keep] = (*host_memory_args)[i];
+ ++keep;
+ }
+ }
+ host_memory_args->resize(keep);
+}
+} // namespace
+
+Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
+ const OpDef& op_def,
+ const NameRangeMap& input_name_map,
+ const NameRangeMap& output_name_map,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types) {
+ Status status;
+ const KernelRegistration* registration;
+ TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, ndef, &registration));
+
+ if (registration != nullptr) {
+ const auto& from_proto = registration->def.host_memory_arg();
+ std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
+ MemoryTypesHelper(input_name_map, &host_memory_args, input_memory_types);
+ MemoryTypesHelper(output_name_map, &host_memory_args, output_memory_types);
+ if (!host_memory_args.empty()) {
+ return errors::InvalidArgument(
+ "HostMemory args '", str_util::Join(host_memory_args, "', '"),
+ "' not found in OpDef: ", SummarizeOpDef(op_def));
+ }
+ }
+ return status;
+}
+
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+ DeviceType device_type, const NodeDef& ndef,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types) {
+ // Look up the Op registered for this op name.
+ Status status;
+ const OpDef* op_def = op_registry->LookUp(ndef.op(), &status);
+ if (op_def == nullptr) return status;
+
+ NameRangeMap inputs, outputs;
+ status = NameRangesForNode(ndef, *op_def, &inputs, &outputs);
+ if (!status.ok()) return status;
+
+ return MemoryTypesForNode(device_type, ndef, *op_def, inputs, outputs,
+ input_memory_types, output_memory_types);
+}
+
+namespace {
+
+bool FindArgInOp(const string& arg_name,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
+ for (const auto& arg : args) {
+ if (arg_name == arg.name()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace
+
+Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) {
+ Status unused_status;
+ for (const auto& key_registration : *GlobalKernelRegistry()) {
+ const KernelDef& kernel_def(key_registration.second.def);
+ const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status);
+ if (op_def == nullptr) {
+ // TODO(josh11b): Make this a hard error.
+ LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString()
+ << "') for unknown op: " << kernel_def.op();
+ continue;
+ }
+ for (const auto& host_memory_arg : kernel_def.host_memory_arg()) {
+ if (!FindArgInOp(host_memory_arg, op_def->input_arg()) &&
+ !FindArgInOp(host_memory_arg, op_def->output_arg())) {
+ return errors::InvalidArgument("HostMemory arg '", host_memory_arg,
+ "' not found in OpDef: ",
+ SummarizeOpDef(*op_def));
+ }
+ }
+ }
+ return Status::OK();
+}
+
+template <>
+const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const {
+ return eigen_cpu_device();
+}
+
+template <>
+const Eigen::GpuDevice& OpKernelContext::eigen_device() const {
+ return eigen_gpu_device();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
new file mode 100644
index 0000000000..34d588c6c9
--- /dev/null
+++ b/tensorflow/core/framework/op_kernel.h
@@ -0,0 +1,1250 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+
+#include <functional>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace Eigen {
+class ThreadPoolDevice;
+class GpuDevice;
+} // end namespace Eigen
+
+namespace tensorflow {
+
+namespace checkpoint {
+class TensorSliceReaderCacheWrapper;
+} // namespace checkpoint
+
+class AsyncOpKernel;
+class OpKernelConstruction; // declared below
+class OpKernelContext; // declared below
+class ResourceMgr;
+
+// TODO(josh11b): Make reference-counted if needed.
+class OpKernel {
+ public:
+ // OpKernel won't be instantiated by the scheduler, so you may perform
+ // expensive initialization in the descendant's constructor.
+ explicit OpKernel(OpKernelConstruction* context);
+ virtual ~OpKernel() {}
+
+ // An OpKernel's computation can be either synchronous or
+ // asynchronous.
+ //
+ // Most OpKernels should compute synchronously. They should
+ // subclass OpKernel and override the Compute() method and have it
+ // return after completing the supplied work.
+ //
+ // A few special kernels might need to be asynchronous to bound the
+ // number of threads (e.g., network receive operations). These
+ // kernels must subclass AsyncOpKernel and override
+ // AsyncOpKernel::ComputeAsync().
+ //
+ // In both cases, implementations of Compute() and ComputeAsync()
+ // get inputs and write outputs through the given OpKernelContext
+ // and return a status via context->SetStatus(). They must be
+ // thread-safe.
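+ //
+ // For example, a minimal synchronous kernel might look like the
+ // following sketch (the "AddOneOp" name and float-only handling are
+ // hypothetical, shown only to illustrate the calling convention):
+ //
+ //   class AddOneOp : public OpKernel {
+ //    public:
+ //     explicit AddOneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ //     void Compute(OpKernelContext* ctx) override {
+ //       const Tensor& in = ctx->input(0);
+ //       Tensor* out = nullptr;
+ //       Status s = ctx->allocate_output(0, in.shape(), &out);
+ //       if (!s.ok()) { ctx->SetStatus(s); return; }
+ //       out->flat<float>() = in.flat<float>() + 1.0f;
+ //     }
+ //   };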
+
+ // Synchronous compute.
+ //
+ // "context" is guaranteed to be alive until Compute() returns.
+ virtual void Compute(OpKernelContext* context) = 0;
+
+ // Returns nullptr iff this op kernel is synchronous.
+ virtual AsyncOpKernel* AsAsync() { return nullptr; }
+
+ // Returns true iff this op kernel is considered "expensive". The
+ // runtime may use this flag to optimize graph execution, for example
+ // to "inline" inexpensive kernels.
+ virtual bool IsExpensive() { return true; }
+
+ // Accessors.
+ const NodeDef& def() const { return def_; }
+ const string& name() const { return def_.name(); }
+ const string& type_string() const { return def_.op(); }
+
+ int num_inputs() const { return input_types_.size(); }
+ DataType input_type(int i) const { return input_types_[i]; }
+ const DataTypeVector& input_types() const { return input_types_; }
+ const MemoryTypeVector& input_memory_types() const {
+ return input_memory_types_;
+ }
+
+ int num_outputs() const { return output_types_.size(); }
+ DataType output_type(int o) const { return output_types_[o]; }
+ const DataTypeVector& output_types() const { return output_types_; }
+ const MemoryTypeVector& output_memory_types() const {
+ return output_memory_types_;
+ }
+
+ Status InputRange(const string& input_name, int* start, int* stop) const;
+ Status OutputRange(const string& output_name, int* start, int* stop) const;
+
+ private:
+ const NodeDef def_;
+ const DataTypeVector input_types_;
+ const DataTypeVector output_types_;
+ NameRangeMap input_name_map_;
+ NameRangeMap output_name_map_;
+ MemoryTypeVector input_memory_types_;
+ MemoryTypeVector output_memory_types_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
+};
+
+class AsyncOpKernel : public OpKernel {
+ public:
+ using OpKernel::OpKernel; // Lift OpKernel constructors.
+
+ // Asynchronous compute.
+ //
+ // Implementations of ComputeAsync() must run "done" to signal the
+ // completion of the computation. "context" is guaranteed to be
+ // alive until the "done" callback starts.
+ typedef std::function<void()> DoneCallback;
+ virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0;
+
+ AsyncOpKernel* AsAsync() final { return this; }
+
+ void Compute(OpKernelContext* context) final;
+};
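+
+// For example, an asynchronous kernel might look like the following
+// sketch (the "RemoteRecvOp" name and the EnqueueReceive() helper are
+// hypothetical; the point is that done() must eventually run):
+//
+//   class RemoteRecvOp : public AsyncOpKernel {
+//    public:
+//     using AsyncOpKernel::AsyncOpKernel;
+//     void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+//       // Hand the work off; "ctx" stays alive until done() starts.
+//       EnqueueReceive(ctx, [ctx, done](const Status& s) {
+//         if (!s.ok()) ctx->SetStatus(s);
+//         done();
+//       });
+//     }
+//   };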
+
+// Wraps a tensor that is held by an Op across calls to Compute(). For
+// memory safety when using asynchronous devices like GPUs, the system
+// must be notified when a Tensor is used inside an Op execution. The
+// wrapper ensures that all uses of the Tensor are tracked, because in
+// order to retrieve the Tensor the caller must use AccessTensor which
+// notifies the context.
+class PersistentTensor {
+ public:
+ PersistentTensor() {}
+ explicit PersistentTensor(const Tensor& tensor) : tensor_(tensor) {}
+
+ // Caller does not own the returned Tensor*.
+ Tensor* AccessTensor(OpKernelConstruction* context);
+ // Caller does not own the returned Tensor*.
+ Tensor* AccessTensor(OpKernelContext* context);
+
+ // The check for initialization does not need to access the
+ // underlying tensor buffer.
+ bool IsInitialized() { return tensor_.IsInitialized(); }
+
+ private:
+ Tensor tensor_;
+};
+
+class OpKernelConstruction {
+ public:
+ // TODO(yuanbyu): Probably reduce the number of arguments.
+ OpKernelConstruction(DeviceType device_type, DeviceBase* device,
+ Allocator* allocator, const NodeDef* node_def,
+ const OpDef* op_def, FunctionLibraryRuntime* flib,
+ const DataTypeSlice& input_types,
+ const DataTypeSlice& output_types, Status* status)
+ : device_type_(device_type),
+ device_(device),
+ allocator_(allocator),
+ def_(node_def),
+ op_def_(op_def),
+ flib_(flib),
+ input_types_(input_types),
+ output_types_(output_types),
+ status_(status) {}
+
+ Env* env() const { return device_->env(); }
+
+ // Allocation of tensors during kernel construction:
+ //
+ // It is legal to temporarily allocate scratch tensor storage during
+ // Op kernel construction. Scratch tensors should be allocated using
+ // allocate_temp below. Some kernels need to keep tensors in between
+ // invocations. If such a Tensor is allocated during kernel
+ // construction this must be done using allocate_persistent, and the
+ // Op may only store the returned PersistentTensor object. When the
+ // Tensor is needed in a subsequent invocation, it can be retrieved
+ // from the PersistentTensor using the AccessTensor method. This
+ // ensures that the system is made aware of any use of the tensor's
+ // allocated memory, which is needed for correctness on asynchronous
+ // devices such as GPUs.
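+ //
+ // For example, a kernel that keeps a tensor between invocations might
+ // do the following (a sketch; the op and member names are hypothetical):
+ //
+ //   class AccumulateOp : public OpKernel {
+ //    public:
+ //     explicit AccumulateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ //       Tensor* unused = nullptr;
+ //       ctx->SetStatus(ctx->allocate_persistent(
+ //           DT_FLOAT, TensorShape({}), &accum_, &unused));
+ //     }
+ //     void Compute(OpKernelContext* ctx) override {
+ //       Tensor* accum = accum_.AccessTensor(ctx);  // notifies the runtime
+ //       // ... read or update *accum ...
+ //     }
+ //    private:
+ //     PersistentTensor accum_;
+ //   };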
+
+ // Allocates a temporary Tensor of the specified type and shape. The
+ // Tensor must not be used after kernel construction is
+ // complete. See comment above.
+ Status allocate_temp(DataType type, const TensorShape& shape,
+ Tensor* out_temp);
+
+ // Allocates a Tensor of the specified type and shape which the Op
+ // plans to maintain as persistent state. out_persistent holds the
+ // PersistentTensor which is the object the caller should store. For
+ // convenience, if out_tensor is non-null then it will be filled in
+ // with a Tensor* pointing to the newly-allocated tensor which the
+ // caller can use instead of calling
+ // out_persistent->AccessTensor. The caller does not own out_tensor
+ // and should not keep a copy of it. See comment above.
+ Status allocate_persistent(DataType type, const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor);
+
+ // User-supplied configuration of this operation.
+ const NodeDef& def() const { return *def_; }
+
+ // Op registered for this op type.
+ const OpDef& op_def() const { return *op_def_; }
+
+ // For inspecting the inputs to this operation.
+ int num_inputs() const { return input_types_.size(); }
+ DataType input_type(int i) const { return input_types_[i]; }
+ const DataTypeSlice& input_types() const { return input_types_; }
+
+ // For inspecting the outputs expected from this operation.
+ int num_outputs() const { return output_types_.size(); }
+ DataType output_type(int i) const { return output_types_[i]; }
+ const DataTypeSlice& output_types() const { return output_types_; }
+
+ // If expected_inputs == input_types() and expected_outputs == output_types(),
+ // returns OK, else returns INVALID_ARGUMENT with an error message.
+ // Recommended for Ops with dynamic signatures.
+ Status MatchSignature(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs);
+
+ // For recording configuration errors during construction.
+ void SetStatus(const Status& status) { status_->Update(status); }
+ const Status& status() const { return *status_; }
+
+ // Look up the attr with name attr_name and set *value to its value. If no
+ // attr with attr_name is found in def(), or the attr does not have
+ // a matching type, a non-ok status will be returned.
+ template <class T>
+ Status GetAttr(const string& attr_name, T* value) const {
+ return GetNodeAttr(def(), attr_name, value);
+ }
+
+ // May be used, e.g., to get GPU handles, etc.
+ // TODO(tucker): Add example usage.
+ DeviceBase* device() const { return device_; }
+
+ // Return the device type.
+ const DeviceType& device_type() const { return device_type_; }
+
+ // If not nullptr, the kernel can instantiate functions defined in
+ // the library. E.g.,
+ // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...).
+ FunctionLibraryRuntime* function_library() const { return flib_; }
+
+ private:
+ const DeviceType device_type_;
+ DeviceBase* const device_;
+ Allocator* allocator_;
+ const NodeDef* def_;
+ const OpDef* op_def_;
+ FunctionLibraryRuntime* flib_;
+ DataTypeSlice input_types_;
+ DataTypeSlice output_types_;
+ Status* status_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpKernelConstruction);
+};
+
+// TODO(mrry): Consider converting to a random_access_iterator, and upgrading
+// tensorflow::gtl::iterator_range to make the below container classes
+// unnecessary.
+template <typename ListType, typename ElementType>
+class OpArgIterator {
+ public:
+ typedef OpArgIterator<ListType, ElementType> ME;
+ OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
+ bool operator==(const ME& rhs) {
+ DCHECK(list_ == rhs.list_);
+ return i_ == rhs.i_;
+ }
+ bool operator!=(const ME& rhs) {
+ DCHECK(list_ == rhs.list_);
+ return i_ != rhs.i_;
+ }
+ void operator++() { ++i_; }
+ ElementType& operator*() { return (*list_)[i_]; }
+
+ private:
+ const ListType* const list_;
+ int i_;
+};
+
+// Utility class for representing a list of immutable input tensors
+// that are passed to the op as a single named argument.
+class OpInputList {
+ public:
+ typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
+ OpInputList(const OpKernelContext* ctx, int start, int stop)
+ : ctx_(ctx), start_(start), stop_(stop) {}
+ OpInputList& operator=(const OpInputList& other) = default;
+ const Tensor& operator[](int i) const;
+ int size() const { return stop_ - start_; }
+ Iterator begin() const { return Iterator(this, 0); }
+ Iterator end() const { return Iterator(this, size()); }
+
+ private:
+ const OpKernelContext* ctx_; // not owned
+ int start_;
+ int stop_;
+};
+
+// Utility class for representing a list of mutable ("ref") input tensors
+// that are passed to the op as a single named argument.
+class OpMutableInputList {
+ public:
+ typedef OpArgIterator<OpMutableInputList, Tensor*> Iterator;
+ OpMutableInputList(OpKernelContext* ctx, int start, int stop)
+ : ctx_(ctx), start_(start), stop_(stop) {}
+ OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {}
+ OpMutableInputList& operator=(const OpMutableInputList& other) = default;
+ Tensor at(int i, bool lock_held);
+ mutex* ref_mutex(int i);
+ int size() const { return stop_ - start_; }
+ Iterator begin() const { return Iterator(this, 0); }
+ Iterator end() const { return Iterator(this, size()); }
+
+ private:
+ OpKernelContext* ctx_; // not owned
+ int start_;
+ int stop_;
+};
+
+// Utility class for representing a list of output tensors that are
+// grouped as a single named output.
+class OpOutputList {
+ public:
+ typedef OpArgIterator<OpOutputList, const Tensor*> Iterator;
+ OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {}
+ OpOutputList(OpKernelContext* ctx, int start, int stop)
+ : ctx_(ctx), start_(start), stop_(stop) {}
+ OpOutputList& operator=(const OpOutputList& other) = default;
+ Tensor* operator[](int i);
+ bool required(int i) const;
+ Status allocate(int i, const TensorShape& shape, Tensor** output);
+ void set(int i, const Tensor& tensor);
+ void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
+ int size() const { return stop_ - start_; }
+ Iterator begin() const { return Iterator(this, 0); }
+ Iterator end() const { return Iterator(this, size()); }
+
+ private:
+ OpKernelContext* ctx_; // not owned
+ int start_;
+ int stop_;
+};
+
+// Holds a tensor or tensor reference. For tensor references, we need
+// a mutex to prevent concurrent access to the tensor.
+struct TensorValue {
+ TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {}
+ TensorValue(Tensor* t) // NOLINT(runtime/explicit)
+ : mutex_if_ref(nullptr),
+ tensor(t) {}
+ TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {}
+ Tensor* operator->() const { return tensor; }
+ bool is_ref() const { return mutex_if_ref != nullptr; }
+
+ mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref
+ Tensor* tensor;
+};
+
+class OpKernelContext {
+ public:
+ // The first element of a WrappedAllocator is a "base" Allocator and
+ // the second element is that Allocator wrapped by a
+ // TrackingAllocator
+ typedef std::pair<Allocator*, TrackingAllocator*> WrappedAllocator;
+
+ // TODO(zhifengc): Do some cleanup of Params.
+ struct Params {
+ // The op kernel being computed.
+ OpKernel* op_kernel = nullptr;
+
+ // The device on which the kernel is running.
+ DeviceBase* device = nullptr;
+
+ bool track_allocations = false;
+ std::function<AllocatorAttributes(int index)> output_alloc_attr = nullptr;
+
+ // Shared resources accessible by this op kernel invocation.
+ ResourceMgr* resource_manager = nullptr;
+
+ // Per-step resources accessible by this op kernel invocation.
+ ResourceMgr* step_resource_manager = nullptr;
+
+ // Mechanism used by this op kernel invocation to communicate with
+ // computations running on other devices.
+ Rendezvous* rendezvous = nullptr;
+
+ // Mechanism used by this op kernel invocation to register a callback
+ // for its cancellation.
+ CancellationManager* cancellation_manager = nullptr;
+
+ // Inputs to this op kernel.
+ const gtl::InlinedVector<TensorValue, 4>* inputs = nullptr;
+ bool is_input_dead = false;
+
+ const gtl::InlinedVector<AllocatorAttributes, 4>* input_alloc_attrs =
+ nullptr;
+
+ // Device contexts.
+ const gtl::InlinedVector<DeviceContext*, 4>* input_device_contexts =
+ nullptr;
+ DeviceContext* op_device_context = nullptr;
+
+ // Control-flow op support.
+ FrameAndIter frame_iter;
+
+ // Function call support.
+ FunctionCallFrame* call_frame = nullptr;
+ FunctionLibraryRuntime* function_library = nullptr;
+
+ // TensorSliceReaderCache support.
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
+ };
+ explicit OpKernelContext(const Params& params);
+ ~OpKernelContext();
+
+ Env* env() const { return params_.device->env(); }
+
+ // Input/output signature.
+
+ int num_inputs() const { return params_.inputs->size(); }
+ DataType input_dtype(int index) const;
+ int num_outputs() const { return outputs_.size(); }
+ DataType expected_output_dtype(int index) const;
+
+ // Input
+
+ // Returns an immutable input tensor. May only be used for non-Ref
+ // inputs. For Ref inputs use mutable_input below.
+ // REQUIRES: !IsRefType(input_dtype(index))
+ // TODO(mrry): Convert this to return Status.
+ const Tensor& input(int index) const;
+
+ // Returns the named immutable input tensor in "tensor", as defined
+ // in the OpDef. May only be used for non-Ref inputs. For Ref inputs
+ // use mutable_input below.
+ // REQUIRES: !IsRefType(input_dtype(index))
+ // REQUIRES: the named input must not be a list.
+ Status input(const string& name, const Tensor** tensor) const;
+
+ // Returns the named list-valued immutable input in "list", as
+ // defined in the OpDef. If the named input is not list-valued,
+ // returns a one-element list. May only be used for non-Ref
+ // inputs. For Ref inputs use mutable_input below.
+ // REQUIRES: !IsRefType(input_dtype(index))
+ Status input_list(const string& name, OpInputList* list) const;
+
+ // For mutable inputs, use the following together to make sure there
+ // is no concurrent access to mutable_input(), e.g.:
+ // {
+ // Tensor& t = context->mutable_input(index);
+ // mutex_lock lock(*context->input_ref_mutex(index));
+ // // modify the values in t
+ // }
+ // REQUIRES: IsRefType(input_dtype(index))
+ // TODO(mrry): Convert this to return Status.
+ mutex* input_ref_mutex(int index);
+ Status input_ref_mutex(const string& name, mutex** out_mutex);
+
+ // Returns a mutable input tensor. Must be used to access Ref
+ // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
+ // modify the values stored in the Tensor buffer, and modifications
+ // will be visible to other Ops reading the same ref tensor. If
+ // !lock_held the input mutex will be acquired before returning the
+ // Tensor.
+ // TODO(mrry):
+ // Convert this to return Status.
+ Tensor mutable_input(int index, bool lock_held);
+
+ // Returns the named mutable input tensor in "tensor", as defined in
+ // the OpDef. Must be used to access Ref inputs. The values stored
+ // in the Tensor buffer may be modified, and modifications will be
+ // visible to other Ops reading the same ref tensor. If !lock_held
+ // the input mutex will be acquired before returning the Tensor.
+ // REQUIRES: the named input must not be a list.
+ // REQUIRES: the named input must be a ref tensor.
+ Status mutable_input(const string& name, Tensor* tensor, bool lock_held);
+
+ // Returns the named list-valued mutable input in "list", as defined
+ // in the OpDef. If the named input is not list-valued, returns a
+ // one-element list. Must be used to access Ref inputs. The values
+ // stored in the Tensor buffer may be modified, and modifications
+ // will be visible to other Ops reading the same ref tensor.
+ // REQUIRES: the named input must be a ref tensor.
+ Status mutable_input_list(const string& name, OpMutableInputList* list);
+
+ // Replace the corresponding Ref Input to use the storage buffer
+ // used by tensor. If !lock_held the input mutex will be acquired
+ // before returning the Tensor.
+ // REQUIRES: IsRefType(input_dtype(index)).
+ void replace_ref_input(int index, const Tensor& tensor, bool lock_held);
+
+ // Replace the corresponding named Ref Input to use the storage
+ // buffer used by tensor. If !lock_held the input mutex will be
+ // acquired before returning the Tensor.
+ // REQUIRES: IsRefType(input_dtype(index)).
+ Status replace_ref_input(const string& name, const Tensor& tensor,
+ bool lock_held);
+
+ // Set the output Ref Tensor at output_index to be an alias of the
+ // input Ref Tensor at input_index.
+ // REQUIRES: IsRefType(input_dtype(input_index)).
+ // REQUIRES: IsRefType(output_dtype(output_index)).
+ void forward_ref_input_to_ref_output(int input_index, int output_index);
+
+ // Deletes the Tensor object used as the Ref Input at
+ // input_index. This is not usually necessary and should be used
+ // with caution. If !lock_held the input mutex will be acquired
+ // before returning the Tensor.
+ // REQUIRES: IsRefType(input_dtype(input_index)).
+ void delete_ref_input(int input_index, bool lock_held);
+
+ // Return true if there is input at the given index. An operator has no
+ // input at index if its tensor is null. This is primarily used by the
+ // merge operator.
+ // TODO(mrry): Convert this to return Status.
+ bool has_input(int index) const;
+
+ // Returns true if all inputs are the same shape, otherwise sets the
+ // status to a non-OK value and returns false.
+ // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
+ bool ValidateInputsAreSameShape(OpKernel* op);
+
+ // Output
+
+ // Returns the named list-valued output in "list", as defined in the OpDef.
+ // If the named output is not list-valued, returns a one-element list.
+ Status output_list(const string& name, OpOutputList* list);
+
+ // If output_required(index) returns true, the OpKernel's Compute() method
+ // should call allocate_output(index, ...), set_output(index, ...),
+ // set_output_ref(index, ...), or set the status to a non-ok value.
+ // If it returns false, the kernel may still produce the output, but is
+ // not required to do so.
+ // TODO(mrry): Convert this to return Status, and implement a string
+ // name version.
+ bool output_required(int index) const {
+ return true; // TODO(josh11b): implement
+ }
+
+ // Allocation of tensors during kernel execution inside the Compute
+ // method:
+ //
+ // There are three methods to allocate Tensors when an Op kernel
+ // executes.
+ //
+ // 1) allocate_persistent. This is only needed for Tensors that will
+ // be stored by the Op between invocations, and it *must* be used
+ // for those Tensors. The call returns a PersistentTensor, and that
+ // is the only object the Op is allowed to hold on to between
+ // invocations. When the Tensor is needed in a subsequent
+ // invocation, it can be retrieved from the PersistentTensor using
+ // the AccessTensor method. This ensures that the system is made
+ // aware of any use of the tensor's allocated memory, which is
+ // needed for correctness on asynchronous devices such as GPUs.
+ //
+ // 2) allocate_output. This should be used to allocate any tensor
+ // that is going to be used as an output from the Op at the end of
+ // the current execution. The caller indicates which output the
+ // Tensor will be assigned to, and the call returns the
+ // newly-allocated Tensor. The Tensor can subsequently be assigned
+ // to during kernel execution, and will be used as the designated
+ // output when the kernel execution completes.
+ //
+ // 3) allocate_temp. This should be used to allocate any scratch
+ // storage that is needed while the kernel is executing, and will
+ // not be retained by the Op.
+ //
+ // In some cases a Tensor needs to be used as an output even though
+ // it was previously allocated elsewhere. The Tensor may have been
+ // passed as an input, or stored in a PersistentTensor during a
+ // previous kernel execution, or allocated earlier in the kernel
+ // execution at a time when it was not known which output it would
+ // be assigned to. In this case the kernel can use set_output or
+ // set_output_ref to indicate that the tensor should be used as the
+ // designated output. It is legal to use any previously-allocated
+ // Tensor as an argument to set_output or set_output_ref, including
+ // Tensors allocated via allocate_temp. There may be a performance
+ // penalty to using a Tensor that was not allocated using
+ // allocate_output. This is because allocate_output uses the
+ // AllocatorAttributes stored in output_alloc_attr for the
+ // designated output. In some cases, using the wrong attributes may
+ // cause an extra copy of the Tensor's buffer.
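+ //
+ // For example, a Compute() body that only learns late which buffer is
+ // the output might do the following (a sketch; "ctx" and "shape" stand
+ // for the OpKernelContext* and the desired TensorShape):
+ //
+ //   Tensor scratch;
+ //   Status s = ctx->allocate_temp(DT_FLOAT, shape, &scratch);
+ //   if (!s.ok()) { ctx->SetStatus(s); return; }
+ //   // ... fill in scratch ...
+ //   ctx->set_output(0, scratch);  // see the note above on extra copies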
+
+ // Allocates output for the specified output index with shape.
+ // OpKernelContext retains ownership of the returned pointer. See
+ // comment above.
+ //
+ // If memory allocation fails, returns an error status.
+ //
+ // REQUIRES: !IsRefType(expected_output_dtype(index))
+ Status allocate_output(int index, const TensorShape& shape,
+ Tensor** tensor) TF_MUST_USE_RESULT;
+ Status allocate_output(const string& name, const TensorShape& shape,
+ Tensor** tensor) TF_MUST_USE_RESULT;
+ // The following methods use the supplied attributes instead of
+ // those in output_alloc_attr. The caller is responsible for
+ // ensuring that the attributes are "compatible" with the
+ // output_alloc_attr, e.g. the tensor is allocated on the correct
+ // device. See comment above.
+ Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
+ AllocatorAttributes attr) TF_MUST_USE_RESULT;
+ Status allocate_output(const string& name, const TensorShape& shape,
+ Tensor** tensor,
+ AllocatorAttributes attr) TF_MUST_USE_RESULT;
+
+ // Allocates a temporary Tensor of the specified type and
+ // shape. Devices such as GPUs that enqueue Ops for lazy execution
+ // may retain references to the temporary tensors after the Op's
+ // Compute method has run. See comment above.
+ Status allocate_temp(DataType type, const TensorShape& shape,
+ Tensor* out_temp, AllocatorAttributes attr);
+ Status allocate_temp(DataType type, const TensorShape& shape,
+ Tensor* out_temp) {
+ return allocate_temp(type, shape, out_temp, AllocatorAttributes());
+ }
+
+ // Allocates a Tensor of the specified type and shape which the Op
+ // plans to maintain as persistent state. out_persistent holds the
+ // PersistentTensor which is the object the caller should store. For
+ // convenience, if out_tensor is non-null then it will be filled in
+ // with a Tensor* pointing to the newly-allocated tensor which the
+ // caller can use instead of calling
+ // out_persistent->AccessTensor. The caller does not own out_tensor
+ // and should not keep a copy of it. See comment above.
+ Status allocate_persistent(DataType type, const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor, AllocatorAttributes attr);
+ Status allocate_persistent(DataType type, const TensorShape& shape,
+ PersistentTensor* out_persistent,
+ Tensor** out_tensor) {
+ return allocate_persistent(type, shape, out_persistent, out_tensor,
+ AllocatorAttributes());
+ }
+
+ // Copies a tensor (allocated by the caller) to the specified output
+ // index. REQUIRES: !IsRefType(expected_output_dtype(index))
+ // REQUIRES: 'tensor' must have the same MemoryType as
+ // output_memory_types[index]. See comment above.
+ // TODO(mrry): Convert this to return Status.
+ void set_output(int index, const Tensor& tensor);
+ Status set_output(const string& name, const Tensor& tensor);
+
+ // To output a reference. Caller retains ownership of mu and tensor_for_ref,
+ // and they must outlive all uses within the step. See comment above.
+ // REQUIRES: IsRefType(expected_output_dtype(index))
+ // TODO(mrry): Convert this to return Status.
+ void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
+ Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref);
+
+ // Returns nullptr if allocate_output() or set_output() have not been called.
+ // TODO(mrry): Convert this to return Status.
+ Tensor* mutable_output(int index);
+ Status mutable_output(const string& name, Tensor** tensor);
+
+ // Transfers ownership of an output tensor to the caller.
+ // NOTE: For non-reference outputs, the caller takes responsibility
+ // for deletion. For reference outputs, the caller does NOT take
+ // responsibility for deletion.
+ // TODO(mrry): Convert this to return Status.
+ TensorValue release_output(int index);
+ Status release_output(const string& name, TensorValue* value);
+
+ // Records device specific state about how the input tensors were
+ // computed.
+ //
+ // If using the templated function, the type must be a subclass
+ // of DeviceContext.
+ //
+ // Returns the DeviceContext used for the input at the given index.
+ // Returns nullptr if no DeviceContext was provided.
+ template <typename T>
+ T* input_device_context(int index);
+ DeviceContext* input_device_context(int index);
+
+ // Return the DeviceContext that should be used for this Op.
+ //
+ // If using the templated function, the type must be a subclass
+ // of DeviceContext.
+ //
+ // Returns nullptr if the device did not provide one.
+ template <typename T>
+ T* op_device_context();
+ DeviceContext* op_device_context() {
+ DeviceContext* ret = params_.op_device_context;
+ if (ret == nullptr) {
+ auto* dev_info = device()->tensorflow_gpu_device_info();
+ if (dev_info) ret = dev_info->default_context;
+ }
+ return ret;
+ }
+
+ AllocatorAttributes input_alloc_attr(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.input_alloc_attrs->size());
+ return (*params_.input_alloc_attrs)[index];
+ }
+
+ AllocatorAttributes output_alloc_attr(int index) const {
+ return params_.output_alloc_attr(index);
+ }
+
+ gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators() const {
+ mutex_lock lock(mu_);
+ gtl::InlinedVector<WrappedAllocator, 4> retrieved = wrapped_allocators_;
+ return retrieved;
+ }
+
+ // Communication.
+ //
+ // An op kernel communicates with outside environment through
+ // Rendezvous Send() and Recv().
+ Rendezvous* rendezvous() const { return params_.rendezvous; }
+
+ // Function call support.
+ //
+ // If this kernel invocation is within a function execution,
+ // call_frame() returns the call frame for the function call.
+ FunctionCallFrame* call_frame() const { return params_.call_frame; }
+
+ // If not nullptr, the kernel can invoke functions defined in the
+ // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...).
+ FunctionLibraryRuntime* function_library() const {
+ return params_.function_library;
+ }
+
+ // Shared resources accessible to this kernel.
+ ResourceMgr* resource_manager() const { return params_.resource_manager; }
+
+ checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const {
+ return params_.slice_reader_cache;
+ }
+
+ // Execution.
+ //
+ // OpKernels can use these eigen devices to carry out their
+ // numerical computation.
+ const Eigen::ThreadPoolDevice& eigen_cpu_device() const {
+ return *device()->eigen_cpu_device();
+ }
+ const Eigen::GpuDevice& eigen_gpu_device() const {
+ return eigen_gpu_device_->device();
+ }
+ template <typename EigenDeviceType>
+ const EigenDeviceType& eigen_device() const;
+
+ // Error handling.
+
+ // If expected_inputs == inputs() and expected_outputs == output_types(),
+ // returns OK, else returns INVALID_ARGUMENT with an error message.
+ // Recommended for Ops with dynamic signatures, where validation can only
+ // be performed at runtime.
+ Status MatchSignature(const DataTypeSlice expected_inputs,
+ const DataTypeSlice expected_outputs);
+
+ // An OpKernel should call SetStatus() if Compute() encounters an
+ // error.
+ void SetStatus(const Status& status) { status_.Update(status); }
+ const Status& status() const { return status_; }
+
+ // Cancellation.
+ //
+ // EXPERIMENTAL. See the implementation in tensorflow::TensorQueue for an
+ // example of how to use this API.
+ CancellationManager* cancellation_manager() const {
+ return params_.cancellation_manager;
+ }
+
+ // Other accessors.
+
+ // For control flow.
+ FrameAndIter frame_iter() const { return params_.frame_iter; }
+ bool is_input_dead() const { return params_.is_input_dead; }
+ bool* is_output_dead() { return &is_output_dead_; }
+
+ // May be used, e.g., to get GPU handles, etc.
+ // TODO(tucker): Add example usage.
+ DeviceBase* device() const { return params_.device; }
+
+ // Access to list of temporary tensors.
+ int num_temps();
+ Tensor* temp(int index);
+
+ // Access to information about whether each output was newly
+ // allocated or copied from an existing tensor
+ AllocationType output_allocation_type(int index) const {
+ return output_allocation_types_[index];
+ }
+
+ private:
+ Allocator* get_allocator(AllocatorAttributes attr) {
+ Allocator* allocator = params_.device->GetAllocator(attr);
+ if (params_.track_allocations) {
+ mutex_lock lock(mu_);
+ for (const auto& wrapped : wrapped_allocators_) {
+ if (wrapped.first == allocator) {
+ return wrapped.second;
+ }
+ }
+ TrackingAllocator* wrapped_allocator = new TrackingAllocator(allocator);
+ wrapped_allocators_.push_back(
+ std::make_pair(allocator, wrapped_allocator));
+ return wrapped_allocator;
+ } else {
+ return allocator;
+ }
+ }
+
+ // Per-step resource manager for use by white-listed internal ops.
+ friend class TemporaryVariableOp;
+ friend class DestroyTemporaryVariableOp;
+ ResourceMgr* step_resource_manager() const {
+ return params_.step_resource_manager;
+ }
+
+ // Internal common method used when allocating tensor memory
+ Status allocate_tensor(DataType type, const TensorShape& shape,
+ Tensor* out_tensor, AllocatorAttributes attr);
+
+ // This is called by PersistentTensor::AccessTensor whenever the
+ // wrapped tensor is retrieved, to ensure the runtime knows that the
+ // Tensor is being accessed within an Op. This is necessary for
+ // memory safety of devices like GPUs that queue Ops for
+ // asynchronous execution after the Compute() method completes.
+ friend class PersistentTensor;
+ void NotifyUseOfPersistentTensor(const Tensor& tensor);
+
+ Status status_;
+ Params params_; // immutable after construction.
+ const PerOpGpuDevice* eigen_gpu_device_; // owned, with a per-op
+ // wrapped allocator
+ mutable mutex mu_; // mutable so const accessors can acquire the lock
+ gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators_ GUARDED_BY(mu_);
+ gtl::InlinedVector<TensorValue, 4> outputs_;
+ gtl::InlinedVector<AllocationType, 4> output_allocation_types_;
+ gtl::InlinedVector<Tensor*, 4> temp_tensors_;
+ bool is_output_dead_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpKernelContext);
+};
+
+// Register your OpKernel by specifying the Op's name, the device the
+// kernel runs on, any type attr constraints for this kernel, any
+// host-memory args, and the class to instantiate. Examples:
+//
+// // A kernel that supports all types.
+// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
+//
+// // The following are equivalent ways of specifying that the kernel only
+// // works if the "T" type attr is set to DT_FLOAT.
+// REGISTER_KERNEL_BUILDER(
+// Name("Sub").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+// SubOp<float>);
+// // (You would then repeat this for every type supported by "Sub".)
+//
+// // This form allows you to specify a list of types as the constraint.
+// REGISTER_KERNEL_BUILDER(Name("Sub")
+// .Device(DEVICE_CPU)
+// .TypeConstraint("T", {DT_FLOAT}),
+// SubOp<float>);
+//
+// // A kernel that expects one of the input tensors in host memory.
+// REGISTER_KERNEL_BUILDER(
+// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp);
+//
+// See kernel_def_builder for details.
+
+// Instantiates an OpKernel that has been registered. Returns nullptr
+// if no kernel is registered for that device type / input signature
+// combination (and sets *status to NOT_FOUND), or if there is an error
+// during construction (and sets *status to INVALID_ARGUMENT). Otherwise,
+// the caller takes ownership
+// of the returned pointer.
+// EXPECTED USAGE: unique_ptr<OpKernel> op = CreateOpKernel(...);
+// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+std::unique_ptr<OpKernel> CreateOpKernel(DeviceType device_type,
+ DeviceBase* device,
+ Allocator* allocator,
+ const NodeDef& def, Status* status);
+Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
+ Allocator* allocator, FunctionLibraryRuntime* flib,
+ const NodeDef& def, OpKernel** kernel);
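+//
+// A minimal usage sketch of the unique_ptr overload (assumes some op,
+// here called "MyOp", is registered for CPU, and that cpu_allocator()
+// and DeviceBase(Env::Default()) are available from the included headers):
+//
+//   NodeDef def;  // must have all attrs specified
+//   // ... fill in def ...
+//   DeviceBase device(Env::Default());
+//   Status status;
+//   std::unique_ptr<OpKernel> op = CreateOpKernel(
+//       DeviceType(DEVICE_CPU), &device, cpu_allocator(), def, &status);
+//   if (op == nullptr) { /* inspect status */ }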
+
+// Returns into 'device_types' the subset of prioritized_types that this
+// binary has registered for the given NodeDef.
+//
+// REQUIRES: * 'device_types' is not nullptr.
+// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+Status SupportedDeviceTypesForNode(
+ const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
+ DeviceTypeVector* device_types);
+
+// Returns into *{input,output}_memory_types the memory type of each
+// {input,output} tensor.
+//
+// REQUIRES: * '*_memory_types' is not nullptr.
+// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
+Status MemoryTypesForNode(DeviceType device_type, const NodeDef& ndef,
+ const OpDef& op_def,
+ const NameRangeMap& input_name_map,
+ const NameRangeMap& output_name_map,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types);
+
+Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
+ DeviceType device_type, const NodeDef& ndef,
+ MemoryTypeVector* input_memory_types,
+ MemoryTypeVector* output_memory_types);
+
+// Call once after Op registration has completed.
+Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry);
+
+// -----------------------------------------------------------------------------
+// OpKernel registration implementation follows, please ignore.
+
+// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax.
+namespace register_kernel {
+typedef ::tensorflow::KernelDefBuilder Name;
+} // namespace register_kernel
+
+#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
+ REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
+
+#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
+ REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
+
+#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
+ static ::tensorflow::kernel_factory::OpKernelRegistrar \
+ registrar__body__##ctr##__object( \
+ ::tensorflow::register_kernel::kernel_builder.Build(), \
+ +[](::tensorflow::OpKernelConstruction* context) \
+ -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
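+
+// For illustration: a line such as
+//   REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
+// expands (roughly) to a file-static OpKernelRegistrar whose KernelDef is
+// produced by Name("Save").Device(DEVICE_CPU).Build() and whose factory is
+// a captureless lambda returning `new SaveOp(context)`; the registrar's
+// constructor inserts that pair into the global kernel registry.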
+
+namespace kernel_factory {
+
+class OpKernelRegistrar {
+ public:
+ typedef OpKernel* (*Factory)(OpKernelConstruction*);
+ OpKernelRegistrar(const KernelDef* kernel_def, Factory factory);
+};
+
+} // namespace kernel_factory
+
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
+inline DataType OpKernelContext::input_dtype(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ const TensorValue& value((*params_.inputs)[index]);
+ if (value.is_ref()) {
+ return MakeRefType(value->dtype());
+ } else {
+ return value->dtype();
+ }
+}
+
+inline DataType OpKernelContext::expected_output_dtype(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.op_kernel->output_types().size());
+ return params_.op_kernel->output_type(index);
+}
+
+inline const Tensor& OpKernelContext::input(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK(!(*params_.inputs)[index].is_ref());
+ return *((*params_.inputs)[index].tensor);
+}
+
+inline Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ // return a copy of the Ref acquired while holding the mutex
+ if (lock_held) {
+ return *((*params_.inputs)[index].tensor);
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ return *((*params_.inputs)[index].tensor);
+ }
+}
+
+inline void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
+ bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ // should only modify the tensor while holding the mutex
+ if (lock_held) {
+ *(*params_.inputs)[index].tensor = tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ *(*params_.inputs)[index].tensor = tensor;
+ }
+}
+
+inline void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
+ int output_index) {
+ DCHECK_GE(input_index, 0);
+ DCHECK_LT(input_index, params_.inputs->size());
+ DCHECK((*params_.inputs)[input_index].is_ref());
+ set_output_ref(output_index, (*params_.inputs)[input_index].mutex_if_ref,
+ (*params_.inputs)[input_index].tensor);
+}
+
+inline void OpKernelContext::delete_ref_input(int index, bool lock_held) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ // should only modify the tensor while holding the mutex
+ if (lock_held) {
+ delete (*params_.inputs)[index].tensor;
+ } else {
+ mutex_lock l(*input_ref_mutex(index));
+ delete (*params_.inputs)[index].tensor;
+ }
+}
+
+// no input if tensor == nullptr.
+inline bool OpKernelContext::has_input(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ return (*params_.inputs)[index].tensor != nullptr;
+}
+
+inline mutex* OpKernelContext::input_ref_mutex(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.inputs->size());
+ DCHECK((*params_.inputs)[index].is_ref());
+ return (*params_.inputs)[index].mutex_if_ref;
+}
+
+inline Status OpKernelContext::allocate_output(int index,
+ const TensorShape& shape,
+ Tensor** output) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, num_outputs());
+ DCHECK(params_.output_alloc_attr);
+ AllocatorAttributes attr = params_.output_alloc_attr(index);
+ return allocate_output(index, shape, output, attr);
+}
+
+inline Status OpKernelContext::allocate_tensor(DataType type,
+ const TensorShape& shape,
+ Tensor* out_tensor,
+ AllocatorAttributes attr) {
+ Allocator* a = get_allocator(attr);
+ Tensor new_tensor(a, type, shape);
+
+ if (!new_tensor.IsInitialized() && shape.num_elements() > 0) {
+ return errors::ResourceExhausted("OOM when allocating tensor with shape ",
+ shape.DebugString());
+ }
+ *out_tensor = new_tensor;
+ return Status::OK();
+}
+
+inline Status OpKernelContext::allocate_output(int index,
+ const TensorShape& shape,
+ Tensor** output,
+ AllocatorAttributes attr) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ // Record the fact that this output tensor was allocated by the Op
+ DCHECK_LT(index, output_allocation_types_.size());
+ output_allocation_types_[index] = AT_ALLOCATED;
+ const DataType type = params_.op_kernel->output_type(index);
+ DCHECK(!IsRefType(type));
+ DCHECK(mutable_output(index) == nullptr);
+ Tensor* output_tensor = new Tensor();
+ Status s = allocate_tensor(type, shape, output_tensor, attr);
+ if (s.ok()) {
+ outputs_[index] = TensorValue(output_tensor);
+ *output = outputs_[index].tensor;
+ }
+ return s;
+}
+
+inline Status OpKernelContext::allocate_temp(DataType type,
+ const TensorShape& shape,
+ Tensor* out_temp,
+ AllocatorAttributes attr) {
+ Status s = allocate_tensor(type, shape, out_temp, attr);
+ if (s.ok()) {
+ if (params_.device->SaveTemporaryTensors()) {
+ // keep a reference to the underlying memory around
+ temp_tensors_.push_back(new Tensor(*out_temp));
+ }
+ }
+ return s;
+}
+
+inline Status OpKernelContext::allocate_persistent(
+ DataType type, const TensorShape& shape, PersistentTensor* out_persistent,
+ Tensor** out_tensor, AllocatorAttributes attr) {
+ // TODO(misard) add specific memory tracking for persistent tensors
+ Tensor persistent;
+ Status s = allocate_tensor(type, shape, &persistent, attr);
+ if (s.ok()) {
+ *out_persistent = PersistentTensor(persistent);
+ // This call saves a reference to the newly-allocated tensor if we
+ // are saving temporary tensors
+ Tensor* allocated = out_persistent->AccessTensor(this);
+ if (out_tensor) {
+ *out_tensor = allocated;
+ }
+ }
+ return s;
+}
+
+inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
+ if (t.IsInitialized() && params_.device->SaveTemporaryTensors()) {
+ // keep a reference to the underlying memory around
+ temp_tensors_.push_back(new Tensor(t));
+ }
+}
+
+inline int OpKernelContext::num_temps() { return temp_tensors_.size(); }
+
+inline Tensor* OpKernelContext::temp(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, temp_tensors_.size());
+ return temp_tensors_[index];
+}
+
+inline void OpKernelContext::set_output(int index, const Tensor& tensor) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ // Record the fact that this output tensor was set by the Op
+ DCHECK_LT(index, output_allocation_types_.size());
+ output_allocation_types_[index] = AT_EXISTING;
+ DCHECK(!IsRefType(params_.op_kernel->output_type(index)));
+ DCHECK_EQ(mutable_output(index), nullptr);
+ outputs_[index] = TensorValue(new Tensor(tensor));
+}
+
+inline void OpKernelContext::set_output_ref(int index, mutex* mu,
+ Tensor* tensor_for_ref) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ // Record the fact that this output tensor was set by reference by the Op.
+ DCHECK_LT(index, output_allocation_types_.size());
+ output_allocation_types_[index] = AT_REF;
+ DCHECK(IsRefType(params_.op_kernel->output_type(index)));
+ outputs_[index] = TensorValue(mu, tensor_for_ref);
+}
+
+inline Tensor* OpKernelContext::mutable_output(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ return outputs_[index].tensor;
+}
+
+inline TensorValue OpKernelContext::release_output(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, outputs_.size());
+ TensorValue value = outputs_[index];
+ outputs_[index] = TensorValue();
+ return value;
+}
+
+template <typename T>
+T* OpKernelContext::op_device_context() {
+ static_assert(std::is_base_of<DeviceContext, T>::value,
+ "T is not a subclass of DeviceContext");
+ return static_cast<T*>(op_device_context());
+}
+
+template <typename T>
+T* OpKernelContext::input_device_context(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.input_device_contexts->size());
+ static_assert(std::is_base_of<DeviceContext, T>::value,
+ "T is not a subclass of DeviceContext");
+ return static_cast<T*>((*params_.input_device_contexts)[index]);
+}
+
+inline DeviceContext* OpKernelContext::input_device_context(int index) {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, params_.input_device_contexts->size());
+ return (*params_.input_device_contexts)[index];
+}
+
+inline const Tensor& OpInputList::operator[](int i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->input(start_ + i);
+}
+
+inline mutex* OpMutableInputList::ref_mutex(int i) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->input_ref_mutex(start_ + i);
+}
+
+inline Tensor OpMutableInputList::at(int i, bool lock_held) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->mutable_input(start_ + i, lock_held);
+}
+
+inline Tensor* OpOutputList::operator[](int i) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->mutable_output(start_ + i);
+}
+
+inline bool OpOutputList::required(int i) const {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->output_required(start_ + i);
+}
+
+inline Status OpOutputList::allocate(int i, const TensorShape& shape,
+ Tensor** output) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ return ctx_->allocate_output(start_ + i, shape, output);
+}
+
+inline void OpOutputList::set(int i, const Tensor& tensor) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+ ctx_->set_output(start_ + i, tensor);
+}
+
+inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, stop_ - start_);
+  ctx_->set_output_ref(start_ + i, mu, tensor_for_ref);
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
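For orientation, here is a minimal sketch of a kernel wired up with the accessors defined above (allocate_output, attribute lookup, and kernel registration). The op name "ExampleFill" and its "value" attribute are hypothetical and are not part of this change; the sketch only illustrates the API surface added by this file.

// Sketch only: a hypothetical scalar-producing kernel, assuming the headers
// introduced in this change.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

REGISTER_OP("ExampleFill").Attr("value: float").Output("o: float");

class ExampleFillOp : public OpKernel {
 public:
  explicit ExampleFillOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // Attributes are read once at construction time.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &value_));
  }
  void Compute(OpKernelContext* ctx) override {
    Tensor* out = nullptr;
    // Allocate output 0 as a scalar and fill it.
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape(), &out));
    out->scalar<float>()() = value_;
  }

 private:
  float value_;
};

REGISTER_KERNEL_BUILDER(Name("ExampleFill").Device(DEVICE_CPU), ExampleFillOp);

}  // namespace tensorflow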
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
new file mode 100644
index 0000000000..9400ef24f8
--- /dev/null
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -0,0 +1,803 @@
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include <memory>
+#include <vector>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include <gtest/gtest.h>
+
+class DummyKernel : public tensorflow::OpKernel {
+ public:
+ explicit DummyKernel(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(tensorflow::OpKernelContext* context) override {}
+};
+
+// Test that registration works outside a namespace.
+REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
+REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
+ DummyKernel);
+
+namespace foo {
+bool match_signature_ = false;
+
+// Test that registration works inside a different namespace.
+class TestOp2 : public ::tensorflow::OpKernel {
+ public:
+ explicit TestOp2(::tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {
+ ::tensorflow::Status status = context->MatchSignature(
+ {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
+ match_signature_ = status.ok();
+ context->SetStatus(status);
+ }
+ void Compute(::tensorflow::OpKernelContext* context) override {}
+};
+
+REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("Test2")
+ .Device(::tensorflow::DEVICE_GPU)
+ .HostMemory("i")
+ .HostMemory("o"),
+ TestOp2);
+} // namespace foo
+
+namespace tensorflow {
+
+// Two operations with the same name but different devices.
+REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");
+
+class TestOp3Cpu : public tensorflow::OpKernel {
+ public:
+ explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
+
+namespace {
+
+class TestOp3Gpu : public tensorflow::OpKernel {
+ public:
+ explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
+
+// An op registered for both CPU and GPU.
+REGISTER_OP("Test4").Input("i: float").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
+
+static std::vector<DeviceType> DeviceTypes() {
+ return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
+}
+
+class OpKernelTest : public ::testing::Test {
+ public:
+ OpKernelTest() : device_(Env::Default()) {}
+
+ protected:
+ NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
+ NodeDefBuilder builder(op_type + "-op", op_type);
+ for (DataType dt : inputs) {
+ builder.Input(FakeInput(dt));
+ }
+ NodeDef node_def;
+ TF_CHECK_OK(builder.Finalize(&node_def));
+ return node_def;
+ }
+
+ void ExpectEqual(const string& what, const DataTypeVector& expected,
+ const DataTypeVector& observed) {
+ EXPECT_EQ(expected.size(), observed.size()) << what;
+ const int size = std::min(expected.size(), observed.size());
+ for (int i = 0; i < size; ++i) {
+ bool match = TypesCompatible(expected[i], observed[i]);
+ EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
+ << ", observed: " << observed[i];
+ }
+ }
+
+ void ExpectSuccess(const string& op_type, DeviceType device_type,
+ const DataTypeVector& inputs,
+ const DataTypeVector& outputs) {
+ Status status;
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(device_type, &device_, cpu_allocator(),
+ CreateNodeDef(op_type, inputs), &status));
+ EXPECT_TRUE(status.ok()) << status;
+ EXPECT_TRUE(op != nullptr);
+ if (op != nullptr) {
+ ExpectEqual("inputs", op->input_types(), inputs);
+ ExpectEqual("outputs", op->output_types(), outputs);
+ }
+ }
+
+ void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
+ error::Code code) {
+ NodeDef node_def;
+ protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ device_type, &device_, cpu_allocator(), node_def, &status));
+ EXPECT_TRUE(op == nullptr);
+ EXPECT_FALSE(status.ok());
+ if (!status.ok()) {
+ LOG(INFO) << "Status message: " << status.error_message();
+ EXPECT_EQ(code, status.code());
+ }
+ }
+
+ private:
+ DeviceBase device_;
+};
+
+TEST_F(OpKernelTest, SuccessCpu) {
+ ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8});
+ ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8});
+}
+
+TEST_F(OpKernelTest, SuccessGpu) {
+ foo::match_signature_ = false;
+ ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32});
+ EXPECT_TRUE(foo::match_signature_);
+}
+
+TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
+ ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {});
+ ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {});
+}
+
+TEST_F(OpKernelTest, CpuTypeRegistered) {
+ NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(1, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+}
+
+TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
+ {
+ // Try a node def of an op that is registered for a specific type
+ // only on CPU.
+ NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(1, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+ }
+ {
+ // Try a node def of an op that is registered for a specific type
+ // only on GPU.
+ NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(1, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
+ }
+ {
+ // Try a node def of an op that is only registered for other types.
+ NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(0, devs.size());
+ }
+
+ {
+ // Try a node def of an op that is registered for both.
+ NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
+ DeviceTypeVector devs;
+ ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+ EXPECT_EQ(2, devs.size());
+ EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
+ EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]);
+ }
+}
+
+TEST_F(OpKernelTest, NotFound) {
+ const auto not_found = error::NOT_FOUND;
+ // Something with that op type name exists, but only with a
+ // different DeviceType.
+ ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(),
+ DEVICE_GPU, not_found);
+ ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(),
+ DEVICE_GPU, not_found);
+ ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(),
+ DEVICE_CPU, not_found);
+
+ // No kernel with that signature registered.
+ ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(),
+ DEVICE_GPU, not_found);
+
+ // Nothing with that op type name exists.
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found);
+ ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found);
+}
+
+TEST_F(OpKernelTest, TooFewInputs) {
+ const auto invalid = error::INVALID_ARGUMENT;
+ NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
+ node_def.clear_input();
+ ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
+ node_def.add_input("a");
+ ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
+}
+
+TEST_F(OpKernelTest, TooManyInputs) {
+ const auto invalid = error::INVALID_ARGUMENT;
+ NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
+ node_def.add_input("c");
+ ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
+}
+
+TEST_F(OpKernelTest, MatchSignatureFails) {
+ const auto invalid = error::INVALID_ARGUMENT;
+ foo::match_signature_ = true;
+ ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU,
+ invalid);
+ EXPECT_FALSE(foo::match_signature_);
+}
+
+class DummyDevice : public DeviceBase {
+ public:
+ DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
+ bool SaveTemporaryTensors() const override { return save_; }
+ Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
+ return cpu_allocator();
+ }
+
+ private:
+ bool save_;
+};
+
+TEST_F(OpKernelTest, SaveTempFalse) {
+ Env* env = Env::Default();
+ OpKernelContext::Params params;
+ params.device = new DummyDevice(env, false);
+ Status status;
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
+ CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status));
+ EXPECT_TRUE(status.ok());
+ params.op_kernel = op.get();
+ OpKernelContext* ctx = new OpKernelContext(params);
+
+ Tensor t;
+ EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
+
+ EXPECT_EQ(0, ctx->num_temps());
+
+ delete ctx;
+ delete params.device;
+}
+
+TEST_F(OpKernelTest, SaveTempTrue) {
+ Env* env = Env::Default();
+ OpKernelContext::Params params;
+ params.device = new DummyDevice(env, true);
+ Status status;
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
+ CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status));
+ EXPECT_TRUE(status.ok());
+ params.op_kernel = op.get();
+ OpKernelContext* ctx = new OpKernelContext(params);
+
+ Tensor t;
+ EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
+
+ EXPECT_EQ(1, ctx->num_temps());
+
+ delete ctx;
+ delete params.device;
+}
+
+class OpKernelBuilderTest : public ::testing::Test {
+ protected:
+ // Each attr is described by a "name|type|value".
+ NodeDef CreateNodeDef(const string& op_type,
+ const std::vector<string>& attrs) {
+ NodeDef node_def;
+ node_def.set_name(op_type + "-op");
+ node_def.set_op(op_type);
+ for (const string& attr_desc : attrs) {
+ std::vector<string> parts = str_util::Split(attr_desc, '|');
+ CHECK_EQ(parts.size(), 3);
+ AttrValue attr_value;
+ CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
+ node_def.mutable_attr()->insert(
+ AttrValueMap::value_type(parts[0], attr_value));
+ }
+ return node_def;
+ }
+
+ std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
+ DeviceType device_type,
+ const std::vector<string>& attrs,
+ DataTypeSlice input_types = {}) {
+ Status status;
+ NodeDef def = CreateNodeDef(op_type, attrs);
+ for (size_t i = 0; i < input_types.size(); ++i) {
+ def.add_input("a:0");
+ }
+
+ Env* env = Env::Default();
+ DeviceBase device(env);
+
+ // Test CreateOpKernel()
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(device_type, &device, cpu_allocator(), def, &status));
+ EXPECT_TRUE(status.ok()) << status;
+ EXPECT_TRUE(op != nullptr);
+ if (op != nullptr) {
+ EXPECT_EQ(input_types.size(), op->num_inputs());
+ EXPECT_EQ(0, op->num_outputs());
+ }
+
+ // Test SupportedDeviceTypesForNode()
+ DeviceTypeVector devices;
+ EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
+ bool found = false;
+ for (DeviceType dt : devices) {
+ if (dt == device_type) {
+ found = true;
+ }
+ }
+ EXPECT_TRUE(found) << "Missing " << device_type << " from "
+ << devices.size() << " devices.";
+
+ // In case the caller wants to use the OpKernel
+ return op;
+ }
+
+ void ExpectFailure(const string& op_type, DeviceType device_type,
+ const std::vector<string>& attrs, error::Code code) {
+ Status status;
+ const NodeDef def = CreateNodeDef(op_type, attrs);
+ Env* env = Env::Default();
+ DeviceBase device(env);
+
+ // Test CreateOpKernel().
+ std::unique_ptr<OpKernel> op(
+ CreateOpKernel(device_type, &device, cpu_allocator(), def, &status));
+ EXPECT_TRUE(op == nullptr);
+ EXPECT_FALSE(status.ok());
+ if (!status.ok()) {
+ LOG(INFO) << "Status message: " << status.error_message();
+ EXPECT_EQ(code, status.code());
+
+ // Test SupportedDeviceTypesForNode().
+ DeviceTypeVector devices;
+ if (errors::IsNotFound(status)) {
+ EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
+ for (DeviceType dt : devices) {
+ EXPECT_NE(dt, device_type);
+ }
+ } else {
+ Status status2 =
+ SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
+ EXPECT_EQ(status.code(), status2.code());
+ }
+ }
+ }
+};
+
+REGISTER_OP("BuildCPU");
+REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderCPU) {
+ ExpectSuccess("BuildCPU", DEVICE_CPU, {});
+ ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
+}
+
+REGISTER_OP("BuildGPU");
+REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderGPU) {
+ ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
+ ExpectSuccess("BuildGPU", DEVICE_GPU, {});
+}
+
+REGISTER_OP("BuildBoth");
+REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderBoth) {
+ ExpectSuccess("BuildBoth", DEVICE_CPU, {});
+ ExpectSuccess("BuildBoth", DEVICE_GPU, {});
+}
+
+REGISTER_OP("BuildTypeAttr").Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
+ ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
+ ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
+ error::NOT_FOUND);
+ ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+ ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
+ error::INVALID_ARGUMENT);
+}
+
+REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
+REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<bool>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
+ ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
+ ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
+ ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
+ {"T|list(type)|[DT_BOOL, DT_BOOL]"});
+ ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
+ error::NOT_FOUND);
+ ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+ ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
+ error::INVALID_ARGUMENT);
+}
+
+REGISTER_OP("DuplicateKernel");
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
+ DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, DuplicateKernel) {
+ const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
+ DeviceTypeVector devs;
+ Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Multiple OpKernel registrations match NodeDef"));
+
+ ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+}
+
+REGISTER_OP("DuplicateKernelForT").Attr("T: type");
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
+ const NodeDef ndef =
+ CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
+ DeviceTypeVector devs;
+ Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Multiple OpKernel registrations match NodeDef"));
+
+ ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
+ error::INVALID_ARGUMENT);
+ ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
+ error::NOT_FOUND);
+}
+
+REGISTER_OP("BadConstraint").Attr("dtype: type");
+REGISTER_KERNEL_BUILDER(Name("BadConstraint")
+ .Device(DEVICE_CPU)
+ // Mistake: "T" should be "dtype".
+ .TypeConstraint<float>("T"),
+ DummyKernel);
+
+TEST_F(OpKernelBuilderTest, BadConstraint) {
+ const NodeDef ndef = CreateNodeDef("BadConstraint", {});
+ DeviceTypeVector devs;
+ Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("OpKernel 'BadConstraint' has constraint on attr "
+ "'T' not in NodeDef"));
+
+ ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
+ error::INVALID_ARGUMENT);
+}
+
+class GetAttrKernel : public ::tensorflow::OpKernel {
+ public:
+ explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
+ string attr_name;
+ OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));
+
+ status.emplace_back("s", context->GetAttr(attr_name, &s));
+ status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
+ status.emplace_back("i", context->GetAttr(attr_name, &i));
+ status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
+ status.emplace_back("i32", context->GetAttr(attr_name, &i32));
+ status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
+ status.emplace_back("f", context->GetAttr(attr_name, &f));
+ status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
+ status.emplace_back("b", context->GetAttr(attr_name, &b));
+ status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
+ status.emplace_back("type", context->GetAttr(attr_name, &type));
+ status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
+ status.emplace_back("type_vector",
+ context->GetAttr(attr_name, &type_vector));
+ status.emplace_back("shape_proto",
+ context->GetAttr(attr_name, &shape_proto));
+ status.emplace_back("shape_proto_list",
+ context->GetAttr(attr_name, &shape_proto_list));
+ status.emplace_back("shape", context->GetAttr(attr_name, &shape));
+ status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
+ }
+ void Compute(::tensorflow::OpKernelContext* context) override {}
+
+ void ExpectOk(std::initializer_list<string> keys) {
+ for (const auto& key_status : status) {
+ // Only the status for keys in "keys" should be ok().
+ bool in_keys = false;
+ for (const string& key : keys) {
+ if (key_status.first == key) {
+ in_keys = true;
+ }
+ }
+ EXPECT_EQ(in_keys, key_status.second.ok())
+ << "key_status: " << key_status.first << ", " << key_status.second;
+ }
+ }
+
+ string s;
+ std::vector<string> s_list;
+ int64 i;
+ std::vector<int64> i_list;
+ int32 i32;
+ std::vector<int32> i32_list;
+ float f;
+ std::vector<float> f_list;
+ bool b;
+ std::vector<bool> b_list;
+ DataType type;
+ std::vector<DataType> type_list;
+ DataTypeVector type_vector;
+ TensorShapeProto shape_proto;
+ std::vector<TensorShapeProto> shape_proto_list;
+ TensorShape shape;
+ std::vector<TensorShape> shape_list;
+ std::vector<std::pair<string, Status>> status;
+};
+
+class GetAttrTest : public OpKernelBuilderTest {};
+
+REGISTER_OP("GetAttrStringList")
+ .Attr("attr_name: string")
+ .Attr("a: list(string)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
+ GetAttrKernel);
+
+TEST_F(GetAttrTest, StringList) {
+ std::unique_ptr<OpKernel> op_kernel =
+ ExpectSuccess("GetAttrStringList", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"s_list"});
+ EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);
+
+ op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|list(string)|['baz']"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({});
+ EXPECT_TRUE(get_attr_kernel->s_list.empty());
+}
+
+REGISTER_OP("GetAttrInt")
+ .Attr("attr_name: string")
+ .Attr("a: int")
+ .Attr("b: list(int)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
+
+TEST_F(GetAttrTest, Int) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i", "i32"});
+ EXPECT_EQ(35, get_attr_kernel->i);
+ EXPECT_EQ(35, get_attr_kernel->i32);
+
+ op_kernel = ExpectSuccess(
+ "GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i_list", "i32_list"});
+ EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
+ EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);
+
+ // 8589934592 == 2^33, too big to fit in an int32
+ op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|int|8589934592",
+ "b|list(int)|[-8589934592]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i"}); // no i32
+ EXPECT_EQ(8589934592ll, get_attr_kernel->i);
+ for (const auto& key_status : get_attr_kernel->status) {
+ if (key_status.first == "i32") {
+ EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
+ EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
+ key_status.second.error_message());
+ }
+ }
+
+ op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|int|8589934592",
+ "b|list(int)|[-8589934592]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"i_list"}); // no i32_list
+ EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
+ for (const auto& key_status : get_attr_kernel->status) {
+ if (key_status.first == "i32_list") {
+ EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
+ EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
+ key_status.second.error_message());
+ }
+ }
+}
+
+REGISTER_OP("GetAttrShape")
+ .Attr("attr_name: string")
+ .Attr("a: shape")
+ .Attr("b: list(shape)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
+
+TEST_F(GetAttrTest, Shape) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrShape", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
+ "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"shape", "shape_proto"});
+ EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
+ EXPECT_EQ("[3]", get_attr_kernel->shape.ShortDebugString());
+
+ op_kernel = ExpectSuccess(
+ "GetAttrShape", DEVICE_CPU,
+ {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
+ "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
+ get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"});
+ ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size());
+ EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(),
+ "dim { size: 2 }");
+ EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(),
+ "dim { size: 4 }");
+ ASSERT_EQ(2, get_attr_kernel->shape_list.size());
+ EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].ShortDebugString());
+ EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].ShortDebugString());
+}
+
+REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
+REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
+
+TEST_F(GetAttrTest, Type) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+ get_attr_kernel->ExpectOk({"type"});
+ EXPECT_EQ(DT_FLOAT, get_attr_kernel->type);
+}
+
+REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
+REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
+ GetAttrKernel);
+
+TEST_F(GetAttrTest, TypeList) {
+ std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
+ "GetAttrTypeList", DEVICE_CPU,
+ {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
+ auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
+
+ get_attr_kernel->ExpectOk({"type_list", "type_vector"});
+ ASSERT_EQ(2, get_attr_kernel->type_list.size());
+ EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]);
+ EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]);
+ ASSERT_EQ(2, get_attr_kernel->type_vector.size());
+ EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]);
+ EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
+}
+
+REGISTER_OP("HostMemoryTest")
+ .Input("a: float")
+ .Input("b: T")
+ .Input("c: N * string")
+ .Output("o: N * T")
+ .Attr("T: type")
+ .Attr("N: int");
+REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
+REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
+ .Device(DEVICE_GPU)
+ .HostMemory("a")
+ .HostMemory("c")
+ .HostMemory("o"),
+ DummyKernel);
+
+TEST(MemoryTypesForNode, Simple) {
+ NodeDef node_def;
+ ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
+ .Input(FakeInput())
+ .Input(FakeInput(DT_BOOL))
+ .Input(FakeInput(3))
+ .Finalize(&node_def));
+ MemoryTypeVector input, output;
+
+ EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
+ &input, &output));
+ EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input);
+ EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output);
+
+ EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
+ &input, &output));
+ EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
+ HOST_MEMORY, HOST_MEMORY}),
+ input);
+ EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output);
+}
+
+class BaseKernel : public ::tensorflow::OpKernel {
+ public:
+ explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(::tensorflow::OpKernelContext* context) override {}
+ virtual int Which() const = 0;
+};
+
+template <int WHICH>
+class LabeledKernel : public BaseKernel {
+ public:
+ using BaseKernel::BaseKernel;
+ int Which() const override { return WHICH; }
+};
+
+class LabelTest : public OpKernelBuilderTest {};
+
+REGISTER_OP("LabeledKernel");
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
+ LabeledKernel<0>);
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
+ LabeledKernel<1>);
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
+ LabeledKernel<2>);
+REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
+ LabeledKernel<3>);
+
+TEST_F(LabelTest, Default) {
+ std::unique_ptr<OpKernel> op_kernel =
+ ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
+ auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
+ EXPECT_EQ(0, get_labeled_kernel->Which());
+}
+
+TEST_F(LabelTest, Specified) {
+ std::unique_ptr<OpKernel> op_kernel =
+ ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
+ auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
+ EXPECT_EQ(1, get_labeled_kernel->Which());
+}
+
+TEST_F(LabelTest, Duplicate) {
+ ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"},
+ error::INVALID_ARGUMENT);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
new file mode 100644
index 0000000000..a39bebd854
--- /dev/null
+++ b/tensorflow/core/framework/op_segment.cc
@@ -0,0 +1,86 @@
+#include "tensorflow/core/framework/op_segment.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+OpSegment::Item::~Item() {
+ for (auto kv : name_kernel) delete kv.second;
+}
+
+OpSegment::OpSegment() {}
+
+OpSegment::~OpSegment() {
+ for (auto kv : sessions_) delete kv.second;
+}
+
+Status OpSegment::FindOrCreate(const string& session_handle,
+ const string& node_name, OpKernel** kernel,
+ CreateKernelFn create_fn) {
+ {
+ mutex_lock l(mu_);
+ auto item = gtl::FindPtrOrNull(sessions_, session_handle);
+ if (item == nullptr) {
+ return errors::NotFound("Session ", session_handle, " is not found.");
+ }
+ *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name);
+ if (*kernel != nullptr) {
+ return Status::OK();
+ }
+ }
+ Status s = create_fn(kernel);
+ if (!s.ok()) {
+ LOG(ERROR) << "Create kernel failed: " << s;
+ return s;
+ }
+ {
+ mutex_lock l(mu_);
+ auto item = gtl::FindPtrOrNull(sessions_, session_handle);
+ if (item == nullptr) {
+ return errors::NotFound("Session ", session_handle, " is not found.");
+ }
+ OpKernel** p_kernel = &(item->name_kernel[node_name]);
+ if (*p_kernel == nullptr) {
+ *p_kernel = *kernel; // Inserts 'kernel' in the map.
+ } else {
+ delete *kernel;
+ *kernel = *p_kernel;
+ }
+ }
+ return Status::OK();
+}
+
+void OpSegment::AddHold(const string& session_handle) {
+ mutex_lock l(mu_);
+ Item** item = &sessions_[session_handle];
+ if (*item == nullptr) {
+ *item = new Item; // num_holds == 1
+ } else {
+ ++((*item)->num_holds);
+ }
+}
+
+void OpSegment::RemoveHold(const string& session_handle) {
+ Item* item = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto siter = sessions_.find(session_handle);
+ if (siter == sessions_.end()) {
+ VLOG(1) << "Session " << session_handle << " is not found.";
+ return;
+ }
+ item = siter->second;
+ if (--(item->num_holds) > 0) {
+ return;
+ } else {
+ sessions_.erase(siter);
+ }
+ }
+ delete item;
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
new file mode 100644
index 0000000000..55249d2a38
--- /dev/null
+++ b/tensorflow/core/framework/op_segment.h
@@ -0,0 +1,67 @@
+#ifndef TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
+#define TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// OpSegment keeps track of OpKernels registered for sessions running
+// on a device.
+//
+// The implementation maintains a two-level map. The first level maps
+// a session handle to a map of registered OpKernels; the second level
+// maps node names to instantiated OpKernel objects.
+//
+// Each second-level map is reference-counted: the caller can call
+// AddHold to obtain a reference on all kernels of a session, ensuring
+// these kernels stay alive until a corresponding RemoveHold is called
+// on the same session.
+class OpSegment {
+ public:
+ OpSegment();
+ ~OpSegment();
+
+ // A hold can be placed on a session, preventing all its kernels
+ // from being deleted.
+ void AddHold(const string& session_handle);
+ void RemoveHold(const string& session_handle);
+
+ // If the kernel for "node_name" has been created in the
+ // "session_handle", returns the existing op kernel in "*kernel".
+  // Otherwise, creates the kernel by calling create_fn(), caches it,
+ // and returns it in "*kernel". If create_fn() fails, returns the
+ // error.
+ //
+ // OpSegment keeps the ownership of the returned "*kernel".
+ typedef std::function<Status(OpKernel**)> CreateKernelFn;
+ Status FindOrCreate(const string& session_handle, const string& node_name,
+ OpKernel** kernel, CreateKernelFn create_fn);
+
+ private:
+ // op name -> OpKernel
+ typedef std::unordered_map<string, OpKernel*> KernelMap;
+ struct Item {
+ int num_holds = 1; // Num of holds put on the session.
+ KernelMap name_kernel; // op name -> kernel.
+ ~Item();
+ };
+
+ // session handle -> item.
+ // Session handles are produced by strings::FpToString()
+ typedef std::unordered_map<string, Item*> SessionMap;
+
+ mutable mutex mu_;
+ SessionMap sessions_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(OpSegment);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_OP_SEGMENT_H_
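As a usage sketch of the hold/lookup pattern described above: the session handle "sess0" and node name "n0" are placeholders, and the kernel factory below simply fails with Unimplemented; a real caller would build the OpKernel via CreateOpKernel(), as the test file that follows does.

// Sketch only: the hold/lookup/release sequence on an OpSegment.
#include "tensorflow/core/framework/op_segment.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

void OpSegmentSketch() {
  OpSegment segment;
  segment.AddHold("sess0");  // keep kernels of session "sess0" alive
  OpKernel* kernel = nullptr;
  Status s = segment.FindOrCreate(
      "sess0", "n0", &kernel,
      [](OpKernel** out) { return errors::Unimplemented("no factory"); });
  // On success, OpSegment retains ownership of *kernel; callers only borrow it.
  segment.RemoveHold("sess0");  // last hold gone: cached kernels are deleted
}

}  // namespace tensorflow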
diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc
new file mode 100644
index 0000000000..6297718df8
--- /dev/null
+++ b/tensorflow/core/framework/op_segment_test.cc
@@ -0,0 +1,142 @@
+#include "tensorflow/core/framework/op_segment.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+class OpSegmentTest : public ::testing::Test {
+ protected:
+ DeviceBase device_;
+ std::vector<NodeDef> int32_nodedefs_;
+ std::vector<NodeDef> float_nodedefs_;
+
+ OpSegmentTest() : device_(Env::Default()) {
+ RequireDefaultOps();
+ for (int i = 0; i < 10; ++i) {
+ NodeDef def;
+ TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
+ .Input("x", 0, DT_INT32)
+ .Input("y", 0, DT_INT32)
+ .Finalize(&def));
+ int32_nodedefs_.push_back(def);
+ TF_CHECK_OK(NodeDefBuilder(strings::StrCat("op", i), "Mul")
+ .Input("x", 0, DT_FLOAT)
+ .Input("y", 0, DT_FLOAT)
+ .Finalize(&def));
+ float_nodedefs_.push_back(def);
+ }
+ }
+
+ void ValidateOpAndTypes(OpKernel* op, const NodeDef& expected, DataType dt) {
+ ASSERT_NE(op, nullptr);
+ EXPECT_EQ(expected.DebugString(), op->def().DebugString());
+ EXPECT_EQ(2, op->num_inputs());
+ EXPECT_EQ(dt, op->input_type(0));
+ EXPECT_EQ(dt, op->input_type(1));
+ EXPECT_EQ(1, op->num_outputs());
+ EXPECT_EQ(dt, op->output_type(0));
+ }
+
+ OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) {
+ return [this, ndef](OpKernel** kernel) {
+ Status s;
+ auto created =
+ CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, &s);
+ if (s.ok()) {
+ *kernel = created.release();
+ }
+ return s;
+ };
+ }
+};
+
+TEST_F(OpSegmentTest, Basic) {
+ OpSegment opseg;
+ OpKernel* op;
+
+ opseg.AddHold("A");
+ opseg.AddHold("B");
+ for (int i = 0; i < 10; ++i) {
+ // Register in session A.
+ auto* ndef = &float_nodedefs_[i];
+ EXPECT_OK(opseg.FindOrCreate("A", ndef->name(), &op, GetFn(ndef)));
+ ValidateOpAndTypes(op, *ndef, DT_FLOAT);
+
+ // Register in session B.
+ ndef = &int32_nodedefs_[i];
+ EXPECT_OK(opseg.FindOrCreate("B", ndef->name(), &op, GetFn(ndef)));
+ ValidateOpAndTypes(op, *ndef, DT_INT32);
+ }
+
+ auto reterr = [](OpKernel** kernel) {
+ return errors::Internal("Should not be called");
+ };
+ for (int i = 0; i < 10; ++i) {
+ // Lookup op in session A.
+ EXPECT_OK(opseg.FindOrCreate("A", strings::StrCat("op", i), &op, reterr));
+ ValidateOpAndTypes(op, float_nodedefs_[i], DT_FLOAT);
+
+ // Lookup op in session B.
+ EXPECT_OK(opseg.FindOrCreate("B", strings::StrCat("op", i), &op, reterr));
+ ValidateOpAndTypes(op, int32_nodedefs_[i], DT_INT32);
+ }
+
+ opseg.RemoveHold("A");
+ opseg.RemoveHold("B");
+}
+
+TEST_F(OpSegmentTest, SessionNotFound) {
+ OpSegment opseg;
+ OpKernel* op;
+ NodeDef def = float_nodedefs_[0];
+ Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
+ EXPECT_TRUE(errors::IsNotFound(s)) << s;
+}
+
+TEST_F(OpSegmentTest, CreateFailure) {
+ OpSegment opseg;
+ OpKernel* op;
+ NodeDef def = float_nodedefs_[0];
+ def.set_op("nonexistop");
+ opseg.AddHold("A");
+ Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def));
+ EXPECT_TRUE(errors::IsNotFound(s)) << s;
+ opseg.RemoveHold("A");
+}
+
+TEST_F(OpSegmentTest, AddRemoveHolds) {
+ OpSegment opseg;
+ OpKernel* op;
+ const auto& ndef = int32_nodedefs_[0];
+
+ // No op.
+ opseg.RemoveHold("null");
+
+  // Thread1 registers the op and wants to ensure it stays alive.
+ opseg.AddHold("foo");
+ EXPECT_OK(opseg.FindOrCreate("foo", ndef.name(), &op, GetFn(&ndef)));
+
+  // Thread2 starts some execution that needs "op" to be alive.
+ opseg.AddHold("foo");
+
+ // Thread1 clears session "foo". E.g., a master sends CleanupGraph
+ // before an execution finishes.
+ opseg.RemoveHold("foo");
+
+ // Thread2 should still be able to access "op".
+ ValidateOpAndTypes(op, ndef, DT_INT32);
+
+  // Thread2 then removes its hold on "foo".
+ opseg.RemoveHold("foo");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h
new file mode 100644
index 0000000000..a765c211cb
--- /dev/null
+++ b/tensorflow/core/framework/queue_interface.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// All implementations must be thread-safe.
+class QueueInterface : public ResourceBase {
+ public:
+ typedef std::vector<Tensor> Tuple;
+ typedef AsyncOpKernel::DoneCallback DoneCallback;
+ typedef std::function<void(const Tuple&)> CallbackWithTuple;
+
+ virtual Status ValidateTuple(const Tuple& tuple) = 0;
+ virtual Status ValidateManyTuple(const Tuple& tuple) = 0;
+
+  // Stashes a function object for future execution that will eventually
+  // enqueue the tuple of tensors into the queue, and returns immediately. The
+ // function object is guaranteed to call 'callback'.
+ virtual void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) = 0;
+
+ // Same as above, but the component tensors are sliced along the 0th dimension
+ // to make multiple queue-element components.
+ virtual void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) = 0;
+
+  // Stashes a function object for future execution that will eventually
+ // dequeue an element from the queue and call 'callback' with that tuple
+ // element as argument.
+ virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0;
+
+ // Same as above, but the stashed function object will attempt to dequeue
+ // num_elements items.
+ virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) = 0;
+
+ // Signals that no more elements will be enqueued, and optionally
+ // cancels pending Enqueue(Many) operations.
+ //
+ // After calling this function, subsequent calls to Enqueue(Many)
+ // will fail. If `cancel_pending_enqueues` is true, all pending
+ // calls to Enqueue(Many) will fail as well.
+ //
+ // After calling this function, all current and subsequent calls to
+ // Dequeue(Many) will fail instead of blocking (though they may
+ // succeed if they can be satisfied by the elements in the queue at
+ // the time it was closed).
+ virtual void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) = 0;
+
+ // Assuming *this represents a shared queue, verify that it matches
+ // another instantiation indicated by node_def.
+ virtual Status MatchesNodeDef(const NodeDef& node_def) = 0;
+
+ // Returns the number of elements in the queue.
+ virtual int32 size() = 0;
+
+ virtual const DataTypeVector& component_dtypes() const = 0;
+
+ string DebugString() override { return "A queue"; }
+
+ protected:
+ virtual ~QueueInterface() {}
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
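Because the Try* methods above are callback-based, a dequeueing kernel typically finishes its work inside the callback. The helper below is a sketch under the assumption that the surrounding async kernel supplies 'queue', 'ctx', and 'done', and that OpKernelContext::status() reflects errors set by the queue.

// Sketch only: driving TryDequeue from an async kernel.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"

namespace tensorflow {

void DequeueIntoOutputs(QueueInterface* queue, OpKernelContext* ctx,
                        AsyncOpKernel::DoneCallback done) {
  queue->TryDequeue(ctx, [ctx, done](const QueueInterface::Tuple& tuple) {
    if (ctx->status().ok()) {
      // Copy each component of the dequeued element to the matching output.
      for (size_t i = 0; i < tuple.size(); ++i) {
        ctx->set_output(static_cast<int>(i), tuple[i]);
      }
    }
    done();  // the callback must always complete the async op
  });
}

}  // namespace tensorflow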
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h
new file mode 100644
index 0000000000..b307c37f01
--- /dev/null
+++ b/tensorflow/core/framework/reader_interface.h
@@ -0,0 +1,66 @@
+#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class QueueInterface;
+class ReaderInterface;
+
+// Readers are the mechanism for reading records from files in
+// TensorFlow graphs. Each supported file format has a corresponding
+// ReaderInterface descendant and a corresponding Op & OpKernel
+// (implemented using ReaderOpKernel from reader_op_kernel.h).
+//
+// To use a Reader, you first encode "work" (some string, typically a
+// filename) in the Reader's "work queue". It then processes the
+// "work" (reading records from the file), to produce key/value
+// strings. The methods of this class are called by ReaderFoo ops,
+// so see ../ops/io_ops.cc for detailed descriptions.
+//
+// All descendants of this class must be thread-safe.
+//
+// See the design document here:
+// https://docs.google.com/document/d/1UAgZOoeehYr20TdzW2CoZ30V-aqQphU4SwKXsW7eJv4/edit#
+
+// TODO(josh11b): Switch this to Async.
+class ReaderInterface : public ResourceBase {
+ public:
+ // Read a single record into *key / *value. May get more work from
+ // *queue if the current work is complete. Sets the status on
+  // *context with an OutOfRange Status if the current work is
+ // complete and the queue is done (closed and empty).
+ // This method may block.
+ virtual void Read(QueueInterface* queue, string* key, string* value,
+ OpKernelContext* context) = 0;
+
+ // Restore this reader to its newly-constructed state.
+ virtual Status Reset() = 0;
+
+ // Accessors
+ virtual int64 NumRecordsProduced() = 0;
+ virtual int64 NumWorkUnitsCompleted() = 0;
+
+ // -- Serialization/Restoration support --
+ // Not all readers will support saving and restoring state.
+ virtual Status SerializeState(string* state) = 0;
+ // Note: Must Reset on error.
+ virtual Status RestoreState(const string& state) = 0;
+
+ string DebugString() override { return "a reader"; }
+
+ protected:
+ virtual ~ReaderInterface() {}
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_op_kernel.cc b/tensorflow/core/framework/reader_op_kernel.cc
new file mode 100644
index 0000000000..719f27d94b
--- /dev/null
+++ b/tensorflow/core/framework/reader_op_kernel.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/framework/reader_op_kernel.h"
+
+namespace tensorflow {
+
+ReaderOpKernel::ReaderOpKernel(OpKernelConstruction* context)
+ : OpKernel(context), have_handle_(false) {
+ OP_REQUIRES_OK(context, context->allocate_persistent(
+ tensorflow::DT_STRING,
+ tensorflow::TensorShape({2}), &handle_, nullptr));
+}
+
+ReaderOpKernel::~ReaderOpKernel() {
+ if (have_handle_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(cinfo_.resource_manager()->Delete<ReaderInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+}
+
+void ReaderOpKernel::Compute(OpKernelContext* ctx) {
+ mutex_lock l(mu_);
+ if (!have_handle_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), false));
+ ReaderInterface* reader;
+ OP_REQUIRES_OK(ctx,
+ cinfo_.resource_manager()->LookupOrCreate<ReaderInterface>(
+ cinfo_.container(), cinfo_.name(), &reader,
+ [this](ReaderInterface** ret) {
+ *ret = factory_();
+ return Status::OK();
+ }));
+ auto h = handle_.AccessTensor(ctx)->flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ have_handle_ = true;
+ }
+ ctx->set_output_ref(0, &mu_, handle_.AccessTensor(ctx));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h
new file mode 100644
index 0000000000..8e5cc50c9b
--- /dev/null
+++ b/tensorflow/core/framework/reader_op_kernel.h
@@ -0,0 +1,42 @@
+#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/reader_interface.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Implementation for ops providing a Reader.
+class ReaderOpKernel : public OpKernel {
+ public:
+ explicit ReaderOpKernel(OpKernelConstruction* context);
+ ~ReaderOpKernel() override;
+
+ void Compute(OpKernelContext* context) override;
+
+ // Must be called by descendants before the first call to Compute()
+ // (typically called during construction). factory must return a
+  // ReaderInterface descendant allocated with new, which ReaderOpKernel
+  // will take ownership of.
+ void SetReaderFactory(std::function<ReaderInterface*()> factory) {
+ mutex_lock l(mu_);
+ DCHECK(!have_handle_);
+ factory_ = factory;
+ }
+
+ private:
+ mutex mu_;
+ bool have_handle_ GUARDED_BY(mu_);
+ PersistentTensor handle_ GUARDED_BY(mu_);
+ ContainerInfo cinfo_;
+ std::function<ReaderInterface*()> factory_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
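A concrete reader op would subclass ReaderOpKernel and install its factory in the constructor, roughly as below. NewMyReader() and the "MyReaderOp" op name are assumptions for illustration only; they stand in for a real ReaderInterface implementation and its registered op.

// Sketch only: wiring a hypothetical reader into ReaderOpKernel.
#include "tensorflow/core/framework/reader_op_kernel.h"

namespace tensorflow {

ReaderInterface* NewMyReader();  // hypothetical factory, defined elsewhere

class MyReaderOp : public ReaderOpKernel {
 public:
  explicit MyReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    // Install the factory before the first Compute(), as required above.
    SetReaderFactory([]() { return NewMyReader(); });
  }
};

REGISTER_KERNEL_BUILDER(Name("MyReaderOp").Device(DEVICE_CPU), MyReaderOp);

}  // namespace tensorflow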
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
new file mode 100644
index 0000000000..18473aea2e
--- /dev/null
+++ b/tensorflow/core/framework/register_types.h
@@ -0,0 +1,90 @@
+#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+// This file is used by cuda code and must remain compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+
+// Macros to apply another macro to lists of supported types. If you change
+// the lists of types, please also update the list in types.cc.
+//
+// See example uses of these macros in core/ops.
+//
+//
+// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
+// times by passing each invocation a data type supported by TensorFlow.
+//
+// The different variations pass different subsets of the types.
+// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
+// The set of types depends on the compilation platform.
+//
+// This can be used to register a different template instantiation of
+// an OpKernel for different signatures, e.g.:
+/*
+ #define REGISTER_PARTITION(type) \
+ REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \
+ PartitionOp<type>);
+ TF_CALL_ALL_TYPES(REGISTER_PARTITION)
+ #undef REGISTER_PARTITION
+*/
+
+#ifndef __ANDROID__
+
+// Call "m" for all number types that support the comparison operations "<" and
+// ">".
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+ m(float); \
+ m(double); \
+ m(int64); \
+ m(int32); \
+ m(uint8); \
+ m(int16); \
+ m(int8)
+
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
+ m(float); \
+ m(double); \
+ m(int64); \
+ m(uint8); \
+ m(int16); \
+ m(int8)
+
+// Call "m" for all number types, including complex64.
+#define TF_CALL_NUMBER_TYPES(m) \
+ TF_CALL_REAL_NUMBER_TYPES(m); \
+ m(complex64)
+
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
+ TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); \
+ m(complex64)
+
+// Call "m" on all types.
+#define TF_CALL_ALL_TYPES(m) \
+ TF_CALL_NUMBER_TYPES(m); \
+ m(bool); \
+ m(string)
+
+// Call "m" on all types supported on GPU.
+#define TF_CALL_GPU_NUMBER_TYPES(m) \
+ m(float); \
+ m(double)
+
+#else // __ANDROID__
+
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+ m(float); \
+ m(int32)
+
+#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float)
+
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)
+
+#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+
+// Maybe we could put an empty macro here for Android?
+#define TF_CALL_GPU_NUMBER_TYPES(m) m(float)
+
+#endif // __ANDROID__
+
+#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
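Beyond the kernel-registration example in the header comment, another common pattern is explicit template instantiation per supported type. ZeroFunctor below is a toy functor invented for this sketch; the point is only how the TF_CALL_* macro expands the one-argument macro for each type.

// Sketch only: instantiate a toy functor for every real number type.
#include "tensorflow/core/framework/register_types.h"

namespace tensorflow {

template <typename T>
struct ZeroFunctor {
  T operator()() const { return T(0); }
};

// The macro body omits the trailing ';' because TF_CALL_* inserts it.
#define INSTANTIATE_ZERO(T) template struct ZeroFunctor<T>
TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_ZERO);
#undef INSTANTIATE_ZERO

}  // namespace tensorflow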
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
new file mode 100644
index 0000000000..7f551ea65f
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -0,0 +1,263 @@
+#include "tensorflow/core/framework/rendezvous.h"
+
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+/* static */
+string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
+ const string& dst_device, const string& name,
+ const FrameAndIter& frame_iter) {
+ // NOTE: ';' is not used in the device name's job name.
+ //
+ // We include both sender and receiver in the key to facilitate
+ // debugging. For correctness, we only need to encode the receiver.
+ //
+ // "src_incarnation" is used to distinguish a worker when it
+ // restarts.
+ return strings::StrCat(src_device, ";", strings::FpToString(src_incarnation),
+ ";", dst_device, ";", name, ";", frame_iter.frame_id,
+ ":", frame_iter.iter_id);
+}
+
+/* static */
+Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
+ // TODO(zhifengc): This code is not fast enough.
+ std::vector<string> parts = str_util::Split(key, ';');
+ if (parts.size() == 5 &&
+ DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
+ strings::StringToFp(parts[1], &out->src_incarnation) &&
+ DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
+ !parts[3].empty()) {
+ out->src_device = parts[0];
+ out->dst_device = parts[2];
+ out->edge_name = parts[3];
+ return Status::OK();
+ }
+ return errors::InvalidArgument("Invalid rendezvous key: ", key);
+}
+
+Rendezvous::~Rendezvous() {}
+
+Status Rendezvous::Recv(const string& key, const Args& recv_args, Tensor* val,
+ bool* is_dead) {
+ Status ret;
+ Notification n;
+ RecvAsync(key, recv_args,
+ [&ret, &n, val, is_dead](const Status& s, const Args& send_args,
+ const Args& recv_args, const Tensor& v,
+ const bool dead) {
+ ret = s;
+ *val = v;
+ *is_dead = dead;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+}
+
+class LocalRendezvousImpl : public Rendezvous {
+ public:
+ explicit LocalRendezvousImpl(bool tolerate_dup_recv)
+ : tolerate_dup_recv_(tolerate_dup_recv) {}
+
+ Status Send(const string& key, const Args& send_args, const Tensor& val,
+ const bool is_dead) override {
+ VLOG(2) << "Send " << this << " " << key;
+ DoneCallback waiter = nullptr;
+ Args recv_args;
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) {
+ return status_;
+ }
+ Item* item = nullptr;
+ Table::iterator iter = table_.find(key);
+ if (iter == table_.end()) {
+ // There is no waiter for this message. Insert the message
+        // into the waiters table. The waiter will pick it up when it
+        // arrives.
+ item = new Item;
+ item->waiter = nullptr;
+ item->value = val;
+ item->is_dead = is_dead;
+ if (send_args.device_context) {
+ send_args.device_context->Ref();
+ item->send_dev_context = send_args.device_context;
+ }
+ item->recv_dev_context = nullptr;
+
+ // The allocator attributes of item->value.
+ item->send_alloc_attrs = send_args.alloc_attrs;
+
+ CHECK(table_.insert({key, item}).second);
+ return Status::OK();
+ } else {
+ item = iter->second;
+ if (item->waiter == nullptr) {
+          // The item has no waiter, which means a previously sent value
+          // is still buffered under this key: this is a duplicated send.
+ return errors::Aborted("Duplicated send: ", key);
+ }
+ // Mark item as complete.
+ item->has_been_recvd = true;
+ waiter = item->waiter;
+ item->waiter = nullptr;
+ // The ref on recv_dev_context transfers below.
+ recv_args.device_context = item->recv_dev_context;
+ recv_args.alloc_attrs = item->recv_alloc_attrs;
+ item->recv_dev_context = nullptr;
+ if (tolerate_dup_recv_) {
+ item->value = val;
+ item->is_dead = is_dead;
+ if (send_args.device_context) {
+ send_args.device_context->Ref();
+ item->send_dev_context = send_args.device_context;
+ }
+ item->send_alloc_attrs = send_args.alloc_attrs;
+ }
+ }
+ } // mutex
+ // Notify the waiter by invoking its done closure, outside scope
+ // of the table lock.
+ waiter(Status::OK(), send_args, recv_args, val, is_dead);
+ if (recv_args.device_context) recv_args.device_context->Unref();
+ return Status::OK();
+ }
+
+ void RecvAsync(const string& key, const Args& recv_args,
+ DoneCallback done) override {
+ VLOG(2) << "Recv " << this << " " << key;
+ mu_.lock();
+ if (!status_.ok()) {
+ // Rendezvous has been aborted.
+ Status s = status_;
+ mu_.unlock();
+ done(s, Args(), recv_args, Tensor(), false);
+ return;
+ }
+ Table::iterator iter = table_.find(key);
+ if (iter != table_.end()) {
+ Item* item = iter->second;
+ if (item->has_been_recvd && !tolerate_dup_recv_) {
+ mu_.unlock();
+ done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
+ Tensor(), false);
+ } else if (item->waiter == nullptr || tolerate_dup_recv_) {
+ // A message has already arrived and is stored in the table
+ // under this key. Consumes the message and invokes the done
+ // closure.
+ Tensor v = item->value;
+ if (!tolerate_dup_recv_) {
+ item->value = Tensor();
+ }
+ item->has_been_recvd = true;
+ // Before dropping the table lock, capture the item values.
+ // DeviceContext is only non-null for non-CPU devices.
+ // If we capture the send_dev_context, we need to hold a ref on
+ // it. Our caller will have a ref on the recv_dev_context,
+ // which is not in our table.
+        DeviceContext* send_dev_context = item->send_dev_context;
+        if (send_dev_context) send_dev_context->Ref();
+        bool is_dead = item->is_dead;
+        // Capture the remaining item fields before dropping the lock, since
+        // the item may be deleted once mu_ is released.
+        Args send_args;
+        send_args.device_context = send_dev_context;
+        send_args.alloc_attrs = item->send_alloc_attrs;
+        mu_.unlock();
+        done(Status::OK(), send_args, recv_args, v, is_dead);
+        if (send_dev_context) send_dev_context->Unref();
+ } else {
+ // Already have a waiter in the waiters table under this key,
+ // which should not happen.
+ mu_.unlock();
+ done(errors::Aborted("Duplicated recv: ", key), Args(), recv_args,
+ Tensor(), false);
+ }
+ return;
+ }
+ // Waiting for a message that has not arrived yet. Insert into the
+ // waiting table. The done closure will be invoked when the
+ // message arrives.
+ Item* item = new Item;
+ item->waiter = done;
+ if (recv_args.device_context) {
+ item->recv_dev_context = recv_args.device_context;
+ item->recv_alloc_attrs = recv_args.alloc_attrs;
+ item->recv_dev_context->Ref();
+ }
+ CHECK(table_.insert({key, item}).second);
+ mu_.unlock();
+ return;
+ }
+
+ void StartAbort(const Status& status) override {
+ CHECK(!status.ok());
+ std::vector<Item*> items;
+ {
+ mutex_lock l(mu_);
+ if (!status_.ok()) return;
+ status_ = status;
+ items.reserve(table_.size());
+ for (const auto& p : table_) items.push_back(p.second);
+ table_.clear();
+ }
+ for (Item* item : items) {
+ if (item->waiter != nullptr) {
+ item->waiter(status, Args(), Args(), Tensor(), false);
+ }
+ delete item;
+ }
+ }
+
+ private:
+ typedef LocalRendezvousImpl ME;
+ const bool tolerate_dup_recv_;
+
+ struct Item {
+ DoneCallback waiter = nullptr;
+ Tensor value;
+ bool is_dead = false;
+ bool has_been_recvd = false;
+ DeviceContext* send_dev_context = nullptr;
+ DeviceContext* recv_dev_context = nullptr;
+ AllocatorAttributes send_alloc_attrs;
+ AllocatorAttributes recv_alloc_attrs;
+
+ ~Item() {
+ if (send_dev_context) {
+ send_dev_context->Unref();
+ }
+ if (recv_dev_context) {
+ recv_dev_context->Unref();
+ }
+ }
+ };
+ typedef std::unordered_map<string, Item*> Table;
+
+ // TODO(zhifengc): shard table_.
+ mutex mu_;
+ Table table_ GUARDED_BY(mu_);
+ Status status_ GUARDED_BY(mu_);
+
+ ~LocalRendezvousImpl() override {
+ for (auto i : table_) {
+ delete i.second;
+ }
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousImpl);
+};
+
+Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv) {
+ return new LocalRendezvousImpl(tolerate_dup_recv);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h
new file mode 100644
index 0000000000..94fbfb2523
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous.h
@@ -0,0 +1,102 @@
+#ifndef TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+#define TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
+
+#include <string>
+
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+// A Rendezvous is an abstraction for passing a Tensor
+// from a producer to a consumer, where the consumer may safely
+// request the Tensor before or after it has been produced. A
+// producer never blocks when using a Rendezvous. A consumer has the
+// choice of making a blocking call or providing a callback: in either
+// case, the consumer receives the Tensor as soon as it is available.
+//
+// A Rendezvous key encodes a single <producer, consumer> pair. It is
+// an error to call Send() or Recv*() more than once with the same
+// key.
+class Rendezvous : public core::RefCounted {
+ public:
+ struct Args {
+ DeviceContext* device_context = nullptr;
+ AllocatorAttributes alloc_attrs;
+ };
+
+ // Constructs a rendezvous key for the tensor named "name", sent from
+ // "src_device" to "dst_device". The tensor is generated in the frame
+ // and iteration specified by "frame_iter".
+ static string CreateKey(const string& src_device, uint64 src_incarnation,
+ const string& dst_device, const string& name,
+ const FrameAndIter& frame_iter);
+
+ // Parses a key constructed by CreateKey and fills in the src/dst
+ // device names and related fields of "*out".
+ struct ParsedKey {
+ string src_device;
+ DeviceNameUtils::ParsedName src;
+ uint64 src_incarnation = 0;
+ string dst_device;
+ DeviceNameUtils::ParsedName dst;
+ string edge_name;
+ };
+ static Status ParseKey(const string& key, ParsedKey* out);
+
+ // The caller is a tensor producer and it sends a message (a tensor
+ // "val" and a bool "is_dead") under the given "key".
+ //
+ // {val, is_dead} is bundled as a message sent and received.
+ // Typically, is_dead is set by some control flow nodes
+ // (e.g., a not-taken branch). "args" is passed by Send to the
+ // Recv function to communicate any information that the Recv
+ // function might need. This is typically only necessary for
+ // Send/Recv on the same worker.
+ //
+ // Send() never blocks.
+ virtual Status Send(const string& key, const Args& args, const Tensor& val,
+ const bool is_dead) = 0;
+
+ // Callback provided by a tensor consumer waiting on the rendezvous.
+ // It will be invoked when the tensor is available, or when a non-OK
+ // status arises in the production of that tensor. It also gets
+ // two Rendezvous::Args, one provided by the sender, the other by the
+ // receiver, which may be needed when a non-CPU device is in use
+ // by either side.
+ typedef std::function<void(const Status&, const Args&, const Args&,
+ const Tensor&, const bool)> DoneCallback;
+
+ virtual void RecvAsync(const string& key, const Args& args,
+ DoneCallback done) = 0;
+
+ // Synchronous wrapper for RecvAsync.
+ Status Recv(const string& key, const Args& args, Tensor* val, bool* is_dead);
+
+ // Aborts all pending and future Send/Recv with the given "status".
+ //
+ // StartAbort() does not wait for ongoing calls to finish.
+ // REQUIRES: !status.ok()
+ virtual void StartAbort(const Status& status) = 0;
+
+ protected:
+ ~Rendezvous() override;
+};
+
+// Returns a Rendezvous instance that is limited to use only by
+// producers and consumers in the local process. The caller assumes
+// ownership of one Ref() on the returned object.
+//
+// If "tolerate_dup_recv" is true, then the Rendezvous will retain
+// already Recv'd values and make them available to duplicate Recv
+// calls. This may be useful if the RPC layer is not reliable, but
+// comes at the cost of higher memory consumption.
+Rendezvous* NewLocalRendezvous(bool tolerate_dup_recv = false);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_RENDEZVOUS_H_
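
A minimal sketch of how a producer and consumer might drive this interface, assuming the headers above and the usual TF_CHECK_OK helper; it mirrors what rendezvous_test.cc below exercises and is illustrative only:

  #include "tensorflow/core/framework/rendezvous.h"
  #include "tensorflow/core/public/tensor.h"

  void ProducerConsumerSketch() {
    tensorflow::Rendezvous* r = tensorflow::NewLocalRendezvous();
    tensorflow::Rendezvous::Args args;  // No DeviceContext needed on CPU.
    tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({}));
    t.scalar<float>()() = 42.0f;
    // Producer side: Send() never blocks.
    TF_CHECK_OK(r->Send("edge0", args, t, /*is_dead=*/false));
    // Consumer side: Recv() is the blocking wrapper around RecvAsync().
    tensorflow::Tensor out;
    bool is_dead = false;
    TF_CHECK_OK(r->Recv("edge0", args, &out, &is_dead));
    r->Unref();  // The caller owns one ref on the returned rendezvous.
  }
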
diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc
new file mode 100644
index 0000000000..32011a468f
--- /dev/null
+++ b/tensorflow/core/framework/rendezvous_test.cc
@@ -0,0 +1,314 @@
+#include "tensorflow/core/framework/rendezvous.h"
+
+#include <gtest/gtest.h>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+TEST(RendezvousTest, Key) {
+ const string key = Rendezvous::CreateKey(
+ "/job:mnist/replica:1/task:2/CPU:0", 7890,
+ "/job:mnist/replica:1/task:2/GPU:0", "var0", FrameAndIter(0, 0));
+ EXPECT_EQ(key,
+ "/job:mnist/replica:1/task:2/CPU:0;"
+ "0000000000001ed2;" // 7890 = 0x1ed2
+ "/job:mnist/replica:1/task:2/GPU:0;"
+ "var0;"
+ "0:0");
+ Rendezvous::ParsedKey parsed;
+ EXPECT_OK(Rendezvous::ParseKey(key, &parsed));
+ EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0");
+ EXPECT_EQ(parsed.src_incarnation, 7890);
+ EXPECT_EQ(parsed.src.type, "CPU");
+ EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/GPU:0");
+ EXPECT_EQ(parsed.dst.type, "GPU");
+
+ EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok());
+ EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
+ "/job:mnist/replica:1/task:2/GPU:0;",
+ &parsed)
+ .ok());
+ EXPECT_FALSE(
+ Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok());
+}
+
+class LocalRendezvousTest : public ::testing::Test {
+ public:
+ LocalRendezvousTest()
+ : threads_(new thread::ThreadPool(Env::Default(), "test", 16)) {
+ rendez_ = NewLocalRendezvous();
+ }
+
+ ~LocalRendezvousTest() override {
+ rendez_->Unref();
+ delete threads_;
+ }
+
+ void SchedClosure(std::function<void()> fn) { threads_->Schedule(fn); }
+
+ Rendezvous* rendez_;
+
+ private:
+ thread::ThreadPool* threads_;
+};
+
+// string -> Tensor<string>
+Tensor V(const string& content) {
+ Tensor tensor(DT_STRING, TensorShape({}));
+ tensor.scalar<string>()() = content;
+ return tensor;
+}
+
+// Tensor<string> -> string
+string V(const Tensor& tensor) {
+ CHECK_EQ(tensor.dtype(), DT_STRING);
+ CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+ return tensor.scalar<string>()();
+}
+
+TEST_F(LocalRendezvousTest, SendRecv) {
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+ EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, V("hello"), false)));
+ Tensor val(DT_STRING);
+ bool is_dead = false;
+ ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
+ EXPECT_EQ("hello", V(val));
+}
+
+TEST_F(LocalRendezvousTest, RecvSend) {
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(10000);
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+ });
+ Tensor val(DT_STRING);
+ bool is_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv("foo", args, &val, &is_dead));
+ EXPECT_EQ("hello", V(val));
+}
+
+TEST_F(LocalRendezvousTest, DuplicateWaiterRecv) {
+ SchedClosure([this]() {
+ Tensor t(DT_STRING);
+ bool is_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
+ ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
+ });
+ Env::Default()->SleepForMicroseconds(1000000);
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+ ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
+ ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
+ EXPECT_EQ("secret msg", V(val));
+}
+
+TEST_F(LocalRendezvousTest, DuplicateSerialRecv) {
+ SchedClosure([this]() {
+ Tensor t(DT_STRING);
+ bool is_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv("foo", args, &t, &is_dead));
+ ASSERT_OK(rendez_->Send("bar", args, t, is_dead));
+ });
+ Env::Default()->SleepForMicroseconds(1000000);
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send("foo", args, V("secret msg"), val_dead));
+ ASSERT_OK(rendez_->Recv("bar", args, &val, &val_dead));
+ EXPECT_EQ("secret msg", V(val));
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+}
+
+// A simple structure that behaves a bit like a blocking counter. The
+// thread that decrements the counter to 0 calls done.Notify(), and the
+// main thread waits for done to be notified.
+struct BlockingState {
+ mutex lock;
+ int counter;
+ Notification done;
+};
+
+TEST_F(LocalRendezvousTest, RandomSendRecv) {
+ static const int N = 1000;
+ BlockingState state;
+ state.counter = N;
+ for (int i = 0; i < N; ++i) {
+ SchedClosure([this, i]() {
+ random::PhiloxRandom philox(testing::RandomSeed() + i, 17);
+ random::SimplePhilox rnd(&philox);
+ Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Send(strings::StrCat(i), args, V(strings::StrCat(i)),
+ false));
+ });
+ SchedClosure([this, &state, i]() {
+ random::PhiloxRandom philox(testing::RandomSeed() + N + i, 17);
+ random::SimplePhilox rnd(&philox);
+ Env::Default()->SleepForMicroseconds(1000 + rnd.Uniform(10000));
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ ASSERT_OK(rendez_->Recv(strings::StrCat(i), args, &val, &val_dead));
+ EXPECT_EQ(strings::StrCat(i), V(val));
+ bool done = false;
+ {
+ mutex_lock l(state.lock);
+ state.counter--;
+ if (state.counter == 0) {
+ done = true;
+ }
+ }
+ if (done) {
+ state.done.Notify();
+ }
+ });
+ }
+
+ state.done.WaitForNotification();
+}
+
+TEST_F(LocalRendezvousTest, RecvAbort) {
+ rendez_->Ref();
+ SchedClosure([this]() {
+ rendez_->StartAbort(errors::Aborted("")); // abort
+ rendez_->Unref();
+ });
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ Status status = rendez_->Recv("foo", args, &val, &val_dead);
+ EXPECT_TRUE(errors::IsAborted(status));
+}
+
+// Similar to RecvAbort. But this test case ensures the main thread
+// Recv() call happens after StartAbort().
+TEST_F(LocalRendezvousTest, RecvSleepAbort) {
+ rendez_->Ref();
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(1000000);
+ rendez_->StartAbort(errors::Aborted("")); // abort
+ rendez_->Unref();
+ });
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ Status status = rendez_->Recv("foo", args, &val, &val_dead);
+ EXPECT_TRUE(errors::IsAborted(status));
+}
+
+TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
+ rendez_->StartAbort(errors::Aborted(""));
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ EXPECT_TRUE(errors::IsAborted(rendez_->Send("foo", args, val, val_dead)));
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv("foo", args, &val, &val_dead)));
+}
+
+class DummyDeviceContext : public DeviceContext {
+ public:
+ explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
+ ~DummyDeviceContext() override {}
+ int stream_id() const { return stream_id_; }
+
+ private:
+ const int stream_id_;
+};
+
+TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
+ Rendezvous::Args args;
+ args.device_context = new DummyDeviceContext(123);
+
+ ASSERT_OK(rendez_->Send("foo", args, V("hello"), false));
+
+ Notification n;
+ Rendezvous::Args args1;
+ args1.device_context = new DummyDeviceContext(1);
+ rendez_->RecvAsync("foo", args1, [&n](const Status& s,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& val, bool is_dead) {
+ CHECK_EQ(123,
+ dynamic_cast<const DummyDeviceContext*>(send_args.device_context)
+ ->stream_id());
+ n.Notify();
+ });
+
+ n.WaitForNotification();
+ args.device_context->Unref();
+ args1.device_context->Unref();
+}
+
+static void BM_SendRecv(int iters) {
+ Rendezvous* rendez = NewLocalRendezvous();
+ Tensor orig = V("val");
+ Tensor val(DT_STRING, TensorShape({}));
+ bool is_dead = false;
+ Rendezvous::Args args;
+ Status s;
+ if (iters > 0) {
+ while (iters--) {
+ s = rendez->Send("foo", args, orig, is_dead);
+ s = rendez->Recv("foo", args, &val, &is_dead);
+ }
+ CHECK_EQ(V(val), V(orig));
+ }
+ rendez->Unref();
+}
+BENCHMARK(BM_SendRecv);
+
+static void BM_RecvSend(int iters) {
+ thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
+
+ // The main thread sends "foo" for iters/2 times and receives "bar"
+ // for iters/2 times. The other thread sends "bar" for iters/2
+ // times and receives "foo" for iters/2 times.
+ Rendezvous* rendez = NewLocalRendezvous();
+ pool->Schedule([rendez, iters]() {
+ Tensor bar = V("bar");
+ Tensor foo(DT_STRING, TensorShape({}));
+ bool is_dead = false;
+ Rendezvous::Args args;
+ Status s;
+ for (int i = 0; i < iters / 2; ++i) {
+ s = rendez->Recv("foo", args, &foo, &is_dead);
+ s = rendez->Send("bar", args, bar, is_dead);
+ }
+ CHECK_EQ("foo", V(foo));
+ });
+ Tensor foo = V("foo");
+ Tensor bar(DT_STRING, TensorShape({}));
+ bool is_dead = false;
+ Rendezvous::Args args;
+ Status s;
+ for (int i = 0; i < iters / 2; ++i) {
+ s = rendez->Send("foo", args, foo, is_dead);
+ s = rendez->Recv("bar", args, &bar, &is_dead);
+ }
+ CHECK_EQ("bar", V(bar));
+ delete pool;
+}
+BENCHMARK(BM_RecvSend);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
new file mode 100644
index 0000000000..42326f068e
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -0,0 +1,146 @@
+#include "tensorflow/core/framework/resource_mgr.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+ResourceMgr::ResourceMgr() : default_container_("localhost") {}
+
+ResourceMgr::ResourceMgr(const string& default_container)
+ : default_container_(default_container) {}
+
+ResourceMgr::~ResourceMgr() { Clear(); }
+
+void ResourceMgr::Clear() {
+ mutex_lock l(mu_);
+ for (const auto& p : containers_) {
+ for (const auto& q : *p.second) {
+ q.second->Unref();
+ }
+ delete p.second;
+ }
+ containers_.clear();
+}
+
+Status ResourceMgr::DoCreate(const string& container, std::type_index type,
+ const string& name, ResourceBase* resource) {
+ {
+ mutex_lock l(mu_);
+ Container** b = &containers_[container];
+ if (*b == nullptr) {
+ *b = new Container;
+ }
+ if ((*b)->insert({{type, name}, resource}).second) {
+ return Status::OK();
+ }
+ }
+ resource->Unref();
+ return errors::AlreadyExists("Resource ", container, "/", name, "/",
+ type.name());
+}
+
+Status ResourceMgr::DoLookup(const string& container, std::type_index type,
+ const string& name,
+ ResourceBase** resource) const {
+ mutex_lock l(mu_);
+ const Container* b = gtl::FindPtrOrNull(containers_, container);
+ if (b == nullptr) {
+ return errors::NotFound("Container ", container, " does not exist.");
+ }
+ auto r = gtl::FindPtrOrNull(*b, {type, name});
+ if (r == nullptr) {
+ return errors::NotFound("Resource ", container, "/", name, "/", type.name(),
+ " does not exist.");
+ }
+ *resource = const_cast<ResourceBase*>(r);
+ (*resource)->Ref();
+ return Status::OK();
+}
+
+Status ResourceMgr::DoDelete(const string& container, std::type_index type,
+ const string& name) {
+ ResourceBase* base = nullptr;
+ {
+ mutex_lock l(mu_);
+ Container* b = gtl::FindPtrOrNull(containers_, container);
+ if (b == nullptr) {
+ return errors::NotFound("Container ", container, " does not exist.");
+ }
+ auto iter = b->find({type, name});
+ if (iter == b->end()) {
+ return errors::NotFound("Resource ", container, "/", name, "/",
+ type.name(), " does not exist.");
+ }
+ base = iter->second;
+ b->erase(iter);
+ }
+ CHECK(base != nullptr);
+ base->Unref();
+ return Status::OK();
+}
+
+Status ResourceMgr::Cleanup(const string& container) {
+ Container* b = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto iter = containers_.find(container);
+ if (iter == containers_.end()) {
+ return errors::NotFound("Container ", container, " does not exist.");
+ }
+ b = iter->second;
+ containers_.erase(iter);
+ }
+ CHECK(b != nullptr);
+ for (const auto& p : *b) {
+ p.second->Unref();
+ }
+ delete b;
+ return Status::OK();
+}
+
+Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef,
+ bool use_node_name_as_default) {
+ CHECK(rmgr);
+ rmgr_ = rmgr;
+ string attr_container;
+ TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "container", &attr_container));
+ static RE2 container_re("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+ if (!attr_container.empty() &&
+ !RE2::FullMatch(attr_container, container_re)) {
+ return errors::InvalidArgument("container contains invalid characters: ",
+ attr_container);
+ }
+ string attr_shared_name;
+ TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "shared_name", &attr_shared_name));
+ if (!attr_shared_name.empty() && (attr_shared_name[0] == '_')) {
+ return errors::InvalidArgument("shared_name cannot start with '_':",
+ attr_shared_name);
+ }
+ if (!attr_container.empty()) {
+ container_ = attr_container;
+ } else {
+ container_ = rmgr_->default_container();
+ }
+ if (!attr_shared_name.empty()) {
+ name_ = attr_shared_name;
+ } else if (use_node_name_as_default) {
+ name_ = ndef.name();
+ } else {
+ resource_is_private_to_kernel_ = true;
+ static std::atomic<int64> counter(0);
+ name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name());
+ }
+ return Status::OK();
+}
+
+string ContainerInfo::DebugString() const {
+ return strings::StrCat("[", container(), ",", name(), ",",
+ resource_is_private_to_kernel() ? "private" : "public",
+ "]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
new file mode 100644
index 0000000000..65e859caf1
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -0,0 +1,280 @@
+#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+
+#include <string>
+#include <typeindex>
+#include <typeinfo>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A ResourceMgr instance keeps track of named and typed resources
+// grouped into containers.
+//
+// Each resource must be represented as a sub-class of ResourceBase,
+// which is reference counted explicitly. Each named resource is
+// registered with ResourceMgr inside a named "container". At any
+// time, there is at most one instance of a resource for a given
+// container name, resource type and resource name.
+//
+// All resources for a given container can be dropped by one call of
+// Cleanup().
+//
+// E.g.,
+// struct MyVar : public ResourceBase {
+// mutex mu;
+// Tensor val;
+// };
+//
+// ResourceMgr rm;
+//
+// // Create a var.
+// MyVar* my_var = new MyVar;
+// my_var->val = Tensor(DT_FLOAT, my_shape);
+// my_var->val.flat<float>().setZero(); // 0 initialized.
+// ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
+//
+// // += a variable.
+// MyVar* my_var = nullptr;
+// Status s = rm.Lookup("my_container", "my_name", &my_var);
+// if (s.ok()) {
+// my_var->val.flat<float>() += grad;
+// }
+// my_var->Unref(); // Or use ScopedUnref().
+// ctx->SetStatus(s);
+class ResourceBase : public core::RefCounted {
+ public:
+ // Returns a debug string for *this.
+ virtual string DebugString() = 0;
+};
+
+class ResourceMgr {
+ public:
+ ResourceMgr();
+ explicit ResourceMgr(const string& default_container);
+ ~ResourceMgr();
+
+ // Returns the default container name for *this.
+ const string& default_container() const { return default_container_; }
+
+ // Creates a resource "name" in the "container". The caller transfers
+ // the ownership of one ref on "resource" to *this.
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ // REQUIRES: resource != nullptr.
+ template <typename T>
+ Status Create(const string& container, const string& name,
+ T* resource) TF_MUST_USE_RESULT;
+
+ // If "container" has a resource "name", returns it in "*resource" and
+ // the caller takes the ownership of one ref on "*resource".
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ // REQUIRES: resource != nullptr
+ template <typename T>
+ Status Lookup(const string& container, const string& name,
+ T** resource) const TF_MUST_USE_RESULT;
+
+ // If "container" has a resource "name", returns it in
+ // "*resource". Otherwise, invokes creator() to create the resource.
+ // The caller takes the ownership of one ref on "*resource".
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ // REQUIRES: resource != nullptr
+ template <typename T>
+ Status LookupOrCreate(const string& container, const string& name,
+ T** resource,
+ std::function<Status(T**)> creator) TF_MUST_USE_RESULT;
+
+ // Deletes the resource "name" from the "container".
+ //
+ // REQUIRES: std::is_base_of<ResourceBase, T>
+ template <typename T>
+ Status Delete(const string& container, const string& name) TF_MUST_USE_RESULT;
+
+ // Deletes all resources from the "container" and removes the container.
+ Status Cleanup(const string& container) TF_MUST_USE_RESULT;
+
+ // Deletes all resources in all containers.
+ void Clear();
+
+ private:
+ typedef std::pair<std::type_index, string> Key;
+ struct KeyHash {
+ std::size_t operator()(const Key& k) const {
+ return Hash64(k.second.data(), k.second.size(), k.first.hash_code());
+ }
+ };
+ struct KeyEqual {
+ bool operator()(const Key& x, const Key& y) const {
+ return (x.second == y.second) && (x.first == y.first);
+ }
+ };
+ typedef std::unordered_map<Key, ResourceBase*, KeyHash, KeyEqual> Container;
+
+ const string default_container_;
+ mutable mutex mu_;
+ std::unordered_map<string, Container*> containers_ GUARDED_BY(mu_);
+
+ Status DoCreate(const string& container, std::type_index type,
+ const string& name,
+ ResourceBase* resource) TF_MUST_USE_RESULT;
+ Status DoLookup(const string& container, std::type_index type,
+ const string& name,
+ ResourceBase** resource) const TF_MUST_USE_RESULT;
+ Status DoDelete(const string& container, std::type_index type,
+ const string& name) TF_MUST_USE_RESULT;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ResourceMgr);
+};
+
+// Policy helper to decide which container/shared_name to use for a
+// stateful kernel that accesses shared resource.
+class ContainerInfo {
+ public:
+ // Analyzes the node attributes of 'ndef' and decides the container
+ // and resource name the kernel should use to access the shared
+ // resource.
+ //
+ // 'ndef' is expected to have node attributes "container" and
+ // "shared_name". Returns non-OK if either is missing or invalid.
+ //
+ // The policy is as follows:
+ // * If the attribute "container" is non-empty, it is used as is.
+ // Otherwise, uses the resource manager's default container.
+ // * If the attribute "shared_name" is non-empty, it is used as is.
+ // Otherwise, if "use_node_name_as_default" is true, the kernel's
+ // node name is used as the resource name. Otherwise, a string
+ // unique to this process is used.
+ Status Init(ResourceMgr* rmgr, const NodeDef& ndef,
+ bool use_node_name_as_default);
+ Status Init(ResourceMgr* rmgr, const NodeDef& ndef) {
+ return Init(rmgr, ndef, false);
+ }
+
+ // The policy decides that the kernel should access the resource in
+ // resource_manager(), the resource is in the container() and its
+ // name is name(). If resource_is_private_to_kernel() is true, the
+ // kernel should delete the resource when the kernel is deleted.
+ ResourceMgr* resource_manager() const { return rmgr_; }
+ const string& container() const { return container_; }
+ const string& name() const { return name_; }
+ bool resource_is_private_to_kernel() const {
+ return resource_is_private_to_kernel_;
+ }
+
+ // Returns a readable string for *this.
+ string DebugString() const;
+
+ private:
+ ResourceMgr* rmgr_ = nullptr;
+ string container_;
+ string name_;
+ bool resource_is_private_to_kernel_ = false;
+};
+
+// Helper for kernels to obtain 'resource' from the
+// ctx->resource_manager().
+//
+// "input_name" specifies the kernel's ref input which gives a string
+// tensor with two elements, which specifies the container and
+// resource name.
+//
+// Returns OK if the resource is found and transfers one ref of
+// *resource to the caller. Otherwise, returns an error.
+template <typename T>
+Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
+ T** resource);
+
+// Implementation details below.
+
+template <typename T>
+void CheckDeriveFromResourceBase() {
+ static_assert(std::is_base_of<ResourceBase, T>::value,
+ "T must derive from ResourceBase");
+}
+
+template <typename T>
+Status ResourceMgr::Create(const string& container, const string& name,
+ T* resource) {
+ CheckDeriveFromResourceBase<T>();
+ CHECK(resource != nullptr);
+ return DoCreate(container, std::type_index(typeid(T)), name, resource);
+}
+
+template <typename T>
+Status ResourceMgr::Lookup(const string& container, const string& name,
+ T** resource) const {
+ CheckDeriveFromResourceBase<T>();
+ ResourceBase* found = nullptr;
+ Status s = DoLookup(container, std::type_index(typeid(T)), name, &found);
+ if (s.ok()) {
+ // It's safe to down cast 'found' to T* since
+ // typeid(T).hash_code() is part of the map key.
+ *resource = static_cast<T*>(found);
+ }
+ return s;
+}
+
+template <typename T>
+Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
+ T** resource,
+ std::function<Status(T**)> creator) {
+ Status s;
+ *resource = nullptr;
+ while (*resource == nullptr) {
+ s = Lookup(container, name, resource);
+ if (s.ok()) break;
+ s = creator(resource);
+ if (!s.ok()) break;
+ s = Create(container, name, *resource);
+ if (s.ok()) {
+ (*resource)->Ref();
+ break;
+ }
+ // Rare event. Concurrent racy creation. Redo the lookup.
+ *resource = nullptr;
+ }
+ return s;
+}
+
+template <typename T>
+Status ResourceMgr::Delete(const string& container, const string& name) {
+ CheckDeriveFromResourceBase<T>();
+ return DoDelete(container, std::type_index(typeid(T)), name);
+}
+
+template <typename T>
+Status GetResourceFromContext(OpKernelContext* ctx, const string& input_name,
+ T** resource) {
+ string container;
+ string shared_name;
+ {
+ mutex* mu;
+ TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
+ mutex_lock l(*mu);
+ Tensor tensor;
+ TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
+ if (tensor.NumElements() != 2) {
+ return errors::InvalidArgument(
+ "Resource handle must have 2 elements, but had shape: ",
+ tensor.shape().DebugString());
+ }
+ container = tensor.flat<string>()(0);
+ shared_name = tensor.flat<string>()(1);
+ }
+ return ctx->resource_manager()->Lookup(container, shared_name, resource);
+}
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
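
To make the LookupOrCreate contract above concrete, a minimal sketch, assuming a hypothetical MyVar resource like the one in the header comment; the retry loop inside LookupOrCreate handles racy concurrent creation, so the caller only ever deals with the single ref it receives:

  #include "tensorflow/core/framework/resource_mgr.h"
  #include "tensorflow/core/public/tensor.h"

  namespace tensorflow {

  // Hypothetical resource type, as in the header comment above.
  struct MyVar : public ResourceBase {
    Tensor val;
    string DebugString() override { return "MyVar"; }
  };

  Status UseVar(ResourceMgr* rm) {
    MyVar* var = nullptr;
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<MyVar>(
        "my_container", "my_name", &var,
        [](MyVar** ret) { *ret = new MyVar; return Status::OK(); }));
    // ... read or update var->val under whatever locking the resource needs ...
    var->Unref();  // LookupOrCreate hands the caller exactly one ref.
    return Status::OK();
  }

  }  // namespace tensorflow
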
diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc
new file mode 100644
index 0000000000..9f7ce3dde3
--- /dev/null
+++ b/tensorflow/core/framework/resource_mgr_test.cc
@@ -0,0 +1,173 @@
+#include "tensorflow/core/framework/resource_mgr.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+class Resource : public ResourceBase {
+ public:
+ explicit Resource(const string& label) : label_(label) {}
+ ~Resource() override {}
+
+ string DebugString() { return strings::StrCat("R/", label_); }
+
+ private:
+ string label_;
+};
+
+class Other : public ResourceBase {
+ public:
+ explicit Other(const string& label) : label_(label) {}
+ ~Other() override {}
+
+ string DebugString() { return strings::StrCat("O/", label_); }
+
+ private:
+ string label_;
+};
+
+template <typename T>
+string Find(const ResourceMgr& rm, const string& container,
+ const string& name) {
+ T* r;
+ TF_CHECK_OK(rm.Lookup(container, name, &r));
+ const string ret = r->DebugString();
+ r->Unref();
+ return ret;
+}
+
+template <typename T>
+string LookupOrCreate(ResourceMgr* rm, const string& container,
+ const string& name, const string& label) {
+ T* r;
+ TF_CHECK_OK(rm->LookupOrCreate<T>(container, name, &r, [&label](T** ret) {
+ *ret = new T(label);
+ return Status::OK();
+ }));
+ const string ret = r->DebugString();
+ r->Unref();
+ return ret;
+}
+
+static void HasError(const Status& s, const string& substr) {
+ EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+ << s << ", expected substring " << substr;
+}
+
+template <typename T>
+Status FindErr(const ResourceMgr& rm, const string& container,
+ const string& name) {
+ T* r;
+ Status s = rm.Lookup(container, name, &r);
+ CHECK(!s.ok());
+ return s;
+}
+
+TEST(ResourceMgrTest, Basic) {
+ ResourceMgr rm;
+ TF_CHECK_OK(rm.Create("foo", "bar", new Resource("cat")));
+ TF_CHECK_OK(rm.Create("foo", "baz", new Resource("dog")));
+ TF_CHECK_OK(rm.Create("foo", "bar", new Other("tiger")));
+
+ // Expected to fail.
+ HasError(rm.Create("foo", "bar", new Resource("kitty")),
+ "Already exists: Resource foo/bar");
+
+ // Expected to be found.
+ EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar"));
+ EXPECT_EQ("R/dog", Find<Resource>(rm, "foo", "baz"));
+ EXPECT_EQ("O/tiger", Find<Other>(rm, "foo", "bar"));
+
+ // Expected to be not found.
+ HasError(FindErr<Resource>(rm, "bar", "foo"), "Not found: Container bar");
+ HasError(FindErr<Resource>(rm, "foo", "xxx"), "Not found: Resource foo/xxx");
+ HasError(FindErr<Other>(rm, "foo", "baz"), "Not found: Resource foo/baz");
+
+ // Delete foo/bar/Resource.
+ TF_CHECK_OK(rm.Delete<Resource>("foo", "bar"));
+ HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Resource foo/bar");
+
+ TF_CHECK_OK(rm.Create("foo", "bar", new Resource("kitty")));
+ EXPECT_EQ("R/kitty", Find<Resource>(rm, "foo", "bar"));
+
+ // Drop the whole container foo.
+ TF_CHECK_OK(rm.Cleanup("foo"));
+ HasError(FindErr<Resource>(rm, "foo", "bar"), "Not found: Container foo");
+}
+
+TEST(ResourceMgr, CreateOrLookup) {
+ ResourceMgr rm;
+ EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "cat"));
+ EXPECT_EQ("R/cat", LookupOrCreate<Resource>(&rm, "foo", "bar", "dog"));
+ EXPECT_EQ("R/cat", Find<Resource>(rm, "foo", "bar"));
+
+ EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "tiger"));
+ EXPECT_EQ("O/tiger", LookupOrCreate<Other>(&rm, "foo", "bar", "lion"));
+ TF_CHECK_OK(rm.Delete<Other>("foo", "bar"));
+ HasError(FindErr<Other>(rm, "foo", "bar"), "Not found: Resource foo/bar");
+}
+
+Status ComputePolicy(const string& attr_container,
+ const string& attr_shared_name,
+ bool use_node_name_as_default, string* result) {
+ ContainerInfo cinfo;
+ ResourceMgr rmgr;
+ NodeDef ndef;
+ ndef.set_name("foo");
+ if (attr_container != "none") {
+ AddNodeAttr("container", attr_container, &ndef);
+ }
+ if (attr_shared_name != "none") {
+ AddNodeAttr("shared_name", attr_shared_name, &ndef);
+ }
+ TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default));
+ *result = cinfo.DebugString();
+ return Status::OK();
+}
+
+string Policy(const string& attr_container, const string& attr_shared_name,
+ bool use_node_name_as_default) {
+ string ret;
+ TF_CHECK_OK(ComputePolicy(attr_container, attr_shared_name,
+ use_node_name_as_default, &ret));
+ return ret;
+}
+
+TEST(ContainerInfo, Basic) {
+ // Correct cases.
+ EXPECT_EQ(Policy("", "", false), "[localhost,_0_foo,private]");
+ EXPECT_EQ(Policy("", "", true), "[localhost,foo,public]");
+ EXPECT_EQ(Policy("", "bar", false), "[localhost,bar,public]");
+ EXPECT_EQ(Policy("", "bar", true), "[localhost,bar,public]");
+ EXPECT_EQ(Policy("cat", "", false), "[cat,_1_foo,private]");
+ EXPECT_EQ(Policy("cat", "", true), "[cat,foo,public]");
+ EXPECT_EQ(Policy("cat", "bar", false), "[cat,bar,public]");
+ EXPECT_EQ(Policy("cat", "bar", true), "[cat,bar,public]");
+}
+
+Status WrongPolicy(const string& attr_container, const string& attr_shared_name,
+ bool use_node_name_as_default) {
+ string dbg;
+ auto s = ComputePolicy(attr_container, attr_shared_name,
+ use_node_name_as_default, &dbg);
+ CHECK(!s.ok());
+ return s;
+}
+
+TEST(ContainerInfo, Error) {
+ // Missing attribute.
+ HasError(WrongPolicy("none", "", false), "No attr");
+ HasError(WrongPolicy("", "none", false), "No attr");
+ HasError(WrongPolicy("none", "none", false), "No attr");
+
+ // Invalid container.
+ HasError(WrongPolicy("12$%", "", false), "container contains invalid char");
+
+ // Invalid shared name.
+ HasError(WrongPolicy("", "_foo", false), "shared_name cannot start with '_'");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto
new file mode 100644
index 0000000000..78610350ec
--- /dev/null
+++ b/tensorflow/core/framework/step_stats.proto
@@ -0,0 +1,58 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_description.proto";
+
+// TODO(tucker): The next 4 message defs are very similar to
+// the *LogEntry messages in profile.proto. They should be
+// unified in one place.
+
+message AllocatorMemoryUsed {
+ string allocator_name = 1;
+ int64 total_bytes = 2;
+ int64 peak_bytes = 3;
+}
+
+enum AllocationType {
+ AT_NOTUSED = 0; // tensor was not filled in
+ AT_ALLOCATED = 1; // tensor was allocated by the Op
+ AT_EXISTING = 2; // tensor was set to share the value of an existing tensor
+ AT_REF = 3; // tensor was set to be a reference to an existing tensor
+}
+
+// Output sizes recorded for a single execution of a graph node.
+message NodeOutput {
+ int32 slot = 1;
+ // Was the tensor allocated by this Op or a previous computation?
+ AllocationType allocation_type = 2;
+ TensorDescription tensor_description = 3;
+};
+
+// Time/size stats recorded for a single execution of a graph node.
+message NodeExecStats {
+ // TODO(tucker): Use some more compact form of node identity than
+ // the full string name. Either all processes should agree on a
+ // global id (cost_id?) for each node, or we should use a hash of
+ // the name.
+ string node_name = 1;
+ int64 all_start_micros = 2;
+ int64 op_start_rel_micros = 3;
+ int64 op_end_rel_micros = 4;
+ int64 all_end_rel_micros = 5;
+ repeated AllocatorMemoryUsed memory = 6;
+ repeated NodeOutput output = 7;
+ string timeline_label = 8;
+ int64 scheduled_micros = 9;
+ uint32 thread_id = 10;
+};
+
+message DeviceStepStats {
+ string device = 1;
+ repeated NodeExecStats node_stats = 2;
+}
+
+message StepStats {
+ repeated DeviceStepStats dev_stats = 1;
+};
diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto
new file mode 100644
index 0000000000..0e6e659f2f
--- /dev/null
+++ b/tensorflow/core/framework/summary.proto
@@ -0,0 +1,67 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// Serialization format for histogram module in
+// core/lib/histogram/histogram.h
+message HistogramProto {
+ double min = 1;
+ double max = 2;
+ double num = 3;
+ double sum = 4;
+ double sum_squares = 5;
+
+ // Parallel arrays encoding the bucket boundaries and the bucket values.
+ // bucket(i) is the count for the bucket i. The range for
+ // a bucket is:
+ // i == 0: -DBL_MAX .. bucket_limit(0)
+ // i != 0: bucket_limit(i-1) .. bucket_limit(i)
+ repeated double bucket_limit = 6 [packed = true];
+ repeated double bucket = 7 [packed = true];
+};
+
+// A Summary is a set of named values to be displayed by the
+// visualizer.
+//
+// Summaries are produced regularly during training, as controlled by
+// the "summary_interval_secs" attribute of the training operation.
+// Summaries are also produced at the end of an evaluation.
+message Summary {
+ message Image {
+ // Dimensions of the image.
+ int32 height = 1;
+ int32 width = 2;
+ // Valid colorspace values are
+ // 1 - grayscale
+ // 2 - grayscale + alpha
+ // 3 - RGB
+ // 4 - RGBA
+ // 5 - DIGITAL_YUV
+ // 6 - BGRA
+ int32 colorspace = 3;
+ // Image data in encoded format. All image formats supported by
+ // image_codec::CoderUtil can be stored here.
+ bytes encoded_image_string = 4;
+ }
+
+ message Value {
+ // Tag name for the data. Will be used as the title of the graph
+ // in the visualizer.
+ //
+ // Tag is usually "op_name:value_name", where "op_name" itself can have
+ // structure to indicate grouping.
+ string tag = 1;
+
+ // Value associated with the tag.
+ oneof value {
+ float simple_value = 2;
+ bytes obsolete_old_style_histogram = 3;
+ Image image = 4;
+ HistogramProto histo = 5;
+ }
+ }
+
+ // Set of values for the summary.
+ repeated Value value = 1;
+}
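
As a sketch of how a scalar summary might be assembled from C++, assuming the standard proto3 generated accessors for the message above (e.g. add_value() and set_simple_value()):

  #include "tensorflow/core/framework/summary.pb.h"

  tensorflow::Summary MakeScalarSummary(float loss) {
    tensorflow::Summary s;
    tensorflow::Summary::Value* v = s.add_value();
    v->set_tag("train:loss");   // "op_name:value_name", as described above.
    v->set_simple_value(loss);  // Selects one member of the 'value' oneof.
    return s;
  }
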
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
new file mode 100644
index 0000000000..4a1b65db97
--- /dev/null
+++ b/tensorflow/core/framework/tensor.cc
@@ -0,0 +1,570 @@
+// Implementation notes:
+//
+// Tensor.cc uses a few templated classes and structs to facilitate
+// implementation of the Tensor class.
+//
+// * Buffer<T>: provides the implementation for a typed array T[n].
+// The array is allocated by the given allocator. It runs T's
+// default constructors and destructors when T is not a simple type
+// (e.g., string), and skips them otherwise.
+//
+// * Helper<T>: provides various routines given type T. The routines
+// includes running the constructor and destructor of T[], encoding
+// an decoding T[] into/from a Cord, etc.
+
+#include "tensorflow/core/public/tensor.h"
+
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+
+namespace tensorflow {
+namespace {
+
+// Typed ref-counted buffer: T[n].
+template <typename T>
+class Buffer : public TensorBuffer {
+ public:
+ Buffer(Allocator* a, int64 n);
+
+ void* data() const override { return data_; }
+ size_t size() const override { return sizeof(T) * elem_; }
+ TensorBuffer* root_buffer() override { return this; }
+ void FillAllocationDescription(AllocationDescription* proto) const override {
+ int64 rb = size();
+ proto->set_requested_bytes(rb);
+ proto->set_allocator_name(alloc_->Name());
+ if (alloc_->TracksAllocationSizes()) {
+ int64 ab = alloc_->AllocatedSize(data_);
+ proto->set_allocated_bytes(ab);
+ }
+ }
+
+ private:
+ Allocator* alloc_;
+ T* data_;
+ int64 elem_;
+
+ ~Buffer() override;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Buffer);
+};
+
+// is_simple<T>::value if T[] can be safely constructed and destructed
+// without running T() and ~T(). We do not use std::is_trivial<T>
+// directly because std::complex<float> is not trivial, but its array
+// can be constructed and destructed without running its default ctor
+// and dtor.
+template <typename T>
+struct is_simple {
+ static const bool value = std::is_trivial<T>::value ||
+ std::is_same<T, complex64>::value ||
+ is_quantized<T>::value;
+};
+
+template <>
+struct is_simple<bfloat16> {
+ static const bool value = true;
+};
+
+// A set of helper functions depending on T.
+template <typename T>
+struct Helper {
+ // By default, we assume T is a simple type (float, int32, etc.)
+ static_assert(is_simple<T>::value, "T is not a simple type.");
+ typedef protobuf::RepeatedField<T> RepeatedFieldType;
+
+ // No constructor to run.
+ static void RunCtor(T* p, int n) {}
+
+ // No destructor to run.
+ static void RunDtor(T* p, int n) {}
+
+ // Encoder of simple type T to a string. We do a copy.
+ template <typename Destination>
+ static void Encode(TensorBuffer* in, int64 n, Destination* out) {
+ DCHECK_EQ(in->size(), sizeof(T) * n);
+ port::AssignRefCounted(StringPiece(in->base<const char>(), in->size()), in,
+ out);
+ }
+
+ // Decoder of simple type T. Copy the bytes from "in" into the
+ // tensor buffer.
+ template <typename Source>
+ static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
+ if (in.size() != sizeof(T) * n) {
+ LOG(ERROR) << "Input size was " << in.size() << " and expected "
+ << sizeof(T) * n;
+ return nullptr;
+ }
+ Buffer<T>* buf = new Buffer<T>(a, n);
+ port::CopyToArray(in, buf->template base<char>());
+ return buf;
+ }
+
+ // Memory usage.
+ static int64 TotalBytes(TensorBuffer* in, int64 n) {
+ DCHECK_EQ(in->size(), sizeof(T) * n);
+ return in->size();
+ }
+};
+
+// Helper specialization for string (the only non-simple type we
+// support).
+template <>
+struct Helper<string> {
+ // Proto message uses RepeatedFieldType to hold repeated T.
+ typedef protobuf::RepeatedPtrField<string> RepeatedFieldType;
+
+ // Runs string's default constructor for p[0], p[1], ..., p[n-1].
+ static void RunCtor(string* p, int n) {
+ for (int i = 0; i < n; ++p, ++i) new (p) string();
+ }
+
+ // Runs T's default destructor for p[0], p[1], ..., p[n-1].
+ static void RunDtor(string* p, int n) {
+ for (int i = 0; i < n; ++p, ++i) p->~string();
+ }
+
+ // Encodes "n" elements of type string stored in "in" into Cord
+ // "out", which is usually the TensorProto::tensor_content.
+ template <typename Destination>
+ static void Encode(TensorBuffer* in, int64 n, Destination* out) {
+ port::EncodeStringList(in->base<const string>(), n, out);
+ }
+
+ // Decodes "n" elements of type string from "in" and constructs a
+ // buffer out of it. Returns nullptr if the decoding fails. "in" is
+ // usually the TensorProto::tensor_content.
+ template <typename Source>
+ static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
+ Buffer<string>* buf = new Buffer<string>(a, n);
+ string* strings = buf->template base<string>();
+ if (port::DecodeStringList(in, strings, n)) {
+ return buf;
+ } else {
+ buf->Unref();
+ return nullptr;
+ }
+ }
+
+ // Returns the estimated memory usage of "n" elements of type T
+ // stored in buffer "in".
+ static int64 TotalBytes(TensorBuffer* in, int n) {
+ int64 tot = in->size();
+ DCHECK_EQ(tot, sizeof(string) * n);
+ const string* p = in->base<const string>();
+ for (int i = 0; i < n; ++i, ++p) tot += p->size();
+ return tot;
+ }
+};
+
+template <typename T>
+struct ProtoHelper {};
+
+// For a C++ type "T" (float, double, int32, etc.), the repeated field
+// "N"_val (float_val, int_val, label_val, etc.) of type "F" (float,
+// int32, string, etc) in the TensorProto is used for serializing the
+// tensor of type "T".
+#define PROTO_TRAITS(T, F, N) \
+ template <> \
+ struct ProtoHelper<T> { \
+ typedef Helper<F>::RepeatedFieldType FieldType; \
+ static FieldType::const_iterator Begin(const TensorProto& proto) { \
+ return proto.N##_val().begin(); \
+ } \
+ static size_t NumElements(const TensorProto& proto) { \
+ return proto.N##_val().size(); \
+ } \
+ static void Fill(const T* data, size_t n, TensorProto* proto) { \
+ typename ProtoHelper<T>::FieldType copy(data, data + n); \
+ proto->mutable_##N##_val()->Swap(&copy); \
+ } \
+ };
+PROTO_TRAITS(float, float, float);
+PROTO_TRAITS(double, double, double);
+PROTO_TRAITS(int32, int32, int);
+PROTO_TRAITS(uint8, int32, int);
+PROTO_TRAITS(int16, int32, int);
+PROTO_TRAITS(int8, int32, int);
+PROTO_TRAITS(int64, int64, int64);
+PROTO_TRAITS(bool, bool, bool);
+PROTO_TRAITS(string, string, string);
+PROTO_TRAITS(qint8, int32, int);
+PROTO_TRAITS(quint8, int32, int);
+#undef PROTO_TRAITS
+
+template <>
+struct ProtoHelper<complex64> {
+ typedef Helper<float>::RepeatedFieldType FieldType;
+ static const complex64* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const complex64*>(proto.scomplex_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.scomplex_val().size() / 2;
+ }
+ static void Fill(const complex64* data, size_t n, TensorProto* proto) {
+ const float* p = reinterpret_cast<const float*>(data);
+ FieldType copy(p, p + n * 2);
+ proto->mutable_scomplex_val()->Swap(&copy);
+ }
+};
+
+template <>
+struct ProtoHelper<qint32> {
+ typedef Helper<int32>::RepeatedFieldType FieldType;
+ static const qint32* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const qint32*>(proto.int_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.int_val().size();
+ }
+ static void Fill(const qint32* data, size_t n, TensorProto* proto) {
+ const int32* p = reinterpret_cast<const int32*>(data);
+ FieldType copy(p, p + n);
+ proto->mutable_int_val()->Swap(&copy);
+ }
+};
+
+template <>
+struct ProtoHelper<bfloat16> {
+ typedef Helper<float>::RepeatedFieldType FieldType;
+ static const bfloat16* Begin(const TensorProto& proto) {
+ return reinterpret_cast<const bfloat16*>(proto.int_val().data());
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.int_val().size();
+ }
+ static void Fill(const bfloat16* data, size_t n, TensorProto* proto) {
+ proto->mutable_int_val()->Reserve(n);
+ for (size_t i = 0; i < n; ++i) {
+ proto->mutable_int_val()->AddAlreadyReserved(data[i].value);
+ }
+ }
+};
+
+template <typename T>
+Buffer<T>::Buffer(Allocator* a, int64 n)
+ : alloc_(a), data_(a->Allocate<T>(n)), elem_(n) {
+ if (data_) Helper<T>::RunCtor(data_, elem_);
+}
+
+template <typename T>
+Buffer<T>::~Buffer() {
+ if (data_) {
+ Helper<T>::RunDtor(data_, elem_);
+ alloc_->Deallocate<T>(data_);
+ }
+}
+
+// Allocates a T[n] buffer. Fills in the buffer with repeated values
+// in "in". If "in" has less values than "n", fills the rest of T[n]
+// with the last value. If "in" has no values, fills T[n] with the
+// default value for T.
+//
+// This routine uses the typed fields (float_val, etc.) in the
+// tensor proto as opposed to the untyped binary representation
+// (tensor_content). This is used when we expect the TensorProto is
+// used by a client program which may not know how to encode a tensor
+// in the compact binary representation.
+template <typename T>
+TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
+ CHECK_GT(n, 0);
+ Buffer<T>* buf = new Buffer<T>(a, n);
+ T* data = buf->template base<T>();
+ const int64 in_n = ProtoHelper<T>::NumElements(in);
+ auto begin = ProtoHelper<T>::Begin(in);
+ if (n <= in_n) {
+ std::copy_n(begin, n, data);
+ } else if (in_n > 0) {
+ std::copy_n(begin, in_n, data);
+ const T& last = *(data + in_n - 1);
+ std::fill_n(data + in_n, n - in_n, last);
+ } else {
+ std::fill_n(data, n, T());
+ }
+ return buf;
+}
+
+// Copies T[n] stored in the buffer "in" into the repeated field in
+// "out" corresponding to type T.
+template <typename T>
+void ToProtoField(const TensorBuffer& in, int64 n, TensorProto* out) {
+ const T* data = in.base<const T>();
+ // NOTE: T may not be the same as
+ // ProtoHelper<T>::FieldType::value_type. E.g., T==int16,
+ // ProtoHelper<T>::FieldType::value_type==int32. If performance is
+ // critical, we can specialize T=float and do memcpy directly.
+ ProtoHelper<T>::Fill(data, n, out);
+}
+
+void RefIfNonNull(core::RefCounted* buf) {
+ if (buf) buf->Ref();
+}
+
+void UnrefIfNonNull(core::RefCounted* buf) {
+ if (buf) buf->Unref();
+}
+
+} // end namespace
+
+Tensor::Tensor() : Tensor(DT_FLOAT) {}
+
+Tensor::Tensor(DataType type) : type_(type), shape_({0}), buf_(nullptr) {}
+
+Tensor::Tensor(const Tensor& other)
+ : type_(other.dtype()), shape_(other.shape()), buf_(other.buf_) {
+ RefIfNonNull(buf_);
+}
+
+Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
+ : type_(type), shape_(shape), buf_(buf) {
+ RefIfNonNull(buf);
+}
+
+bool Tensor::IsInitialized() const {
+ return buf_ != nullptr && buf_->data() != nullptr;
+}
+
+Tensor::~Tensor() { UnrefIfNonNull(buf_); }
+
+void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
+ CHECK_EQ(shape.num_elements(), other.NumElements());
+ type_ = other.dtype();
+ shape_ = shape;
+ if (buf_ != other.buf_) {
+ UnrefIfNonNull(buf_);
+ buf_ = other.buf_;
+ RefIfNonNull(buf_);
+ }
+}
+
+// The macro CASES() expands to a switch statement conditioned on
+// TYPE_ENUM. Each case expands the STMTS after a typedef for T.
+#define SINGLE_ARG(...) __VA_ARGS__
+#define CASE(TYPE, STMTS) \
+ case DataTypeToEnum<TYPE>::value: { \
+ typedef TYPE T; \
+ STMTS; \
+ break; \
+ }
+#define CASES(TYPE_ENUM, STMTS) \
+ switch (TYPE_ENUM) { \
+ CASE(float, SINGLE_ARG(STMTS)) \
+ CASE(double, SINGLE_ARG(STMTS)) \
+ CASE(int32, SINGLE_ARG(STMTS)) \
+ CASE(uint8, SINGLE_ARG(STMTS)) \
+ CASE(int16, SINGLE_ARG(STMTS)) \
+ CASE(int8, SINGLE_ARG(STMTS)) \
+ CASE(string, SINGLE_ARG(STMTS)) \
+ CASE(complex64, SINGLE_ARG(STMTS)) \
+ CASE(int64, SINGLE_ARG(STMTS)) \
+ CASE(bool, SINGLE_ARG(STMTS)) \
+ CASE(qint32, SINGLE_ARG(STMTS)) \
+ CASE(quint8, SINGLE_ARG(STMTS)) \
+ CASE(qint8, SINGLE_ARG(STMTS)) \
+ CASE(bfloat16, SINGLE_ARG(STMTS)) \
+ case DT_INVALID: \
+ LOG(FATAL) << "Type not set"; \
+ break; \
+ default: \
+ LOG(FATAL) << "Unexpected type: " << TYPE_ENUM; \
+ break; \
+ }
+
+Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
+ : type_(type), shape_(shape), buf_(nullptr) {
+ CHECK_NOTNULL(a);
+ if (shape_.num_elements() > 0) {
+ CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
+ }
+}
+
+Tensor::Tensor(DataType type, const TensorShape& shape)
+ : Tensor(cpu_allocator(), type, shape) {}
+
+template <typename T>
+class SubBuffer : public TensorBuffer {
+ public:
+ // This buffer is an alias to buf[delta, delta + n).
+ SubBuffer(TensorBuffer* buf, int64 delta, int64 n)
+ : root_(buf->root_buffer()), data_(buf->base<T>() + delta), elem_(n) {
+ // Sanity check. The caller should ensure the sub buffer is valid.
+ CHECK_LE(root_->base<T>(), this->base<T>());
+ T* root_limit = root_->base<T>() + root_->size() / sizeof(T);
+ CHECK_LE(this->base<T>(), root_limit);
+ CHECK_LE(this->base<T>() + n, root_limit);
+ // Hold a ref of the underlying root buffer.
+ // NOTE: 'buf' is a sub-buffer inside the 'root_' buffer.
+ root_->Ref();
+ }
+
+ void* data() const override { return data_; }
+ size_t size() const override { return sizeof(T) * elem_; }
+ TensorBuffer* root_buffer() override { return root_; }
+ void FillAllocationDescription(AllocationDescription* proto) const override {
+ root_->FillAllocationDescription(proto);
+ }
+
+ private:
+ TensorBuffer* root_;
+ T* data_;
+ int64 elem_;
+
+ ~SubBuffer() override { root_->Unref(); }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SubBuffer);
+};
+
+Tensor Tensor::Slice(int64 start, int64 limit) const {
+ CHECK_GE(dims(), 1);
+ CHECK_LE(0, start);
+ CHECK_LE(start, limit);
+ int64 dim0_size = shape_.dim_size(0);
+ CHECK_LE(limit, dim0_size);
+ if ((start == 0) && (limit == dim0_size)) {
+ return *this;
+ }
+ Tensor ret;
+ ret.type_ = type_;
+ ret.shape_ = shape_;
+ ret.buf_ = nullptr;
+ if (dim0_size > 0) {
+ const int64 elems_per_dim0 = NumElements() / dim0_size;
+ const int64 delta = start * elems_per_dim0;
+ dim0_size = limit - start;
+ ret.shape_.set_dim(0, dim0_size);
+ const int64 num_elems = dim0_size * elems_per_dim0;
+ if (buf_) {
+ CASES(type_, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
+ }
+ }
+ return ret;
+}
+
+bool Tensor::FromProto(const TensorProto& proto) {
+ return FromProto(cpu_allocator(), proto);
+}
+
+bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
+ CHECK_NOTNULL(a);
+ TensorBuffer* p = nullptr;
+ if (!TensorShape::IsValid(proto.tensor_shape())) return false;
+ if (proto.dtype() == DT_INVALID) return false;
+ TensorShape shape(proto.tensor_shape());
+ const int64 N = shape.num_elements();
+ if (N > 0 && proto.dtype()) {
+ if (!proto.tensor_content().empty()) {
+ const auto& content = proto.tensor_content();
+ CASES(proto.dtype(), p = Helper<T>::Decode(a, content, N));
+ } else {
+ CASES(proto.dtype(), p = FromProtoField<T>(a, proto, N));
+ }
+ if (p == nullptr) return false;
+ }
+ type_ = proto.dtype();
+ shape_ = shape;
+ UnrefIfNonNull(buf_);
+ buf_ = p;
+ return true;
+}
+
+void Tensor::AsProtoField(TensorProto* proto) const {
+ proto->Clear();
+ proto->set_dtype(dtype());
+ shape_.AsProto(proto->mutable_tensor_shape());
+ if (buf_) {
+ CASES(dtype(), ToProtoField<T>(*buf_, shape_.num_elements(), proto));
+ }
+}
+
+void Tensor::AsProtoTensorContent(TensorProto* proto) const {
+ proto->Clear();
+ proto->set_dtype(type_);
+ shape_.AsProto(proto->mutable_tensor_shape());
+ if (buf_) {
+ CASES(dtype(), Helper<T>::Encode(buf_, shape_.num_elements(),
+ proto->mutable_tensor_content()));
+ }
+}
+
+size_t Tensor::TotalBytes() const {
+ if (shape_.num_elements() == 0) return 0;
+ CHECK(buf_) << "null buf_ with non-zero shape size " << shape_.num_elements();
+ CASES(dtype(), return Helper<T>::TotalBytes(buf_, shape_.num_elements()));
+ return 0; // Makes compiler happy.
+}
+
+bool Tensor::CanUseDMA() const {
+ CASES(dtype(), return is_simple<T>::value);
+ return false; // Makes compiler happy.
+}
+
+#undef CASES
+#undef CASE
+
+string Tensor::SummarizeValue(int64 max_entries) const {
+ string ret;
+ for (int64 i = 0; i < std::min(max_entries, NumElements()); ++i) {
+ if (i > 0) strings::StrAppend(&ret, " ");
+ switch (dtype()) {
+ case DT_STRING:
+ strings::StrAppend(&ret, str_util::CEscape(flat<string>()(i)));
+ break;
+ case DT_BOOL:
+ strings::StrAppend(&ret, flat<bool>()(i) ? "True" : "False");
+ break;
+
+#define CASE(DT_ENUM) \
+ case DT_ENUM: \
+ strings::StrAppend(&ret, flat<EnumToDataType<DT_ENUM>::Type>()(i)); \
+ break
+
+ CASE(DT_FLOAT);
+ CASE(DT_DOUBLE);
+ CASE(DT_INT32);
+ CASE(DT_UINT8);
+ CASE(DT_INT16);
+ CASE(DT_INT8);
+ CASE(DT_INT64);
+
+#undef CASE
+ default:
+ // TODO(zhifengc, josh11b): Pretty-print other types (bool,
+ // complex64, quantized, bfloat16).
+ strings::StrAppend(&ret, " ?");
+ }
+ }
+ if (max_entries < NumElements()) strings::StrAppend(&ret, "...");
+
+ return ret;
+}
+
+StringPiece Tensor::tensor_data() const {
+ if (buf_ == nullptr) return StringPiece(); // Don't die for empty tensors
+ return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
+}
+
+string Tensor::DebugString() const {
+ return strings::StrCat("Tensor<type: ", DataTypeString(dtype()), " shape: ",
+ shape().ShortDebugString(), " values: ",
+ SummarizeValue(3), ">");
+}
+
+void Tensor::FillDescription(TensorDescription* description) const {
+ description->set_dtype(dtype());
+ shape().AsProto(description->mutable_shape());
+ buf_->FillAllocationDescription(
+ description->mutable_allocation_description());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto
new file mode 100644
index 0000000000..b42694afde
--- /dev/null
+++ b/tensorflow/core/framework/tensor.proto
@@ -0,0 +1,57 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Protocol buffer representing a tensor.
+message TensorProto {
+ DataType dtype = 1;
+
+ // Shape of the tensor. TODO(mdevin): sort out the 0-rank issues.
+ TensorShapeProto tensor_shape = 2;
+
+ // Only one of the representations below is set, either "tensor_content" or
+ // one of the "xxx_val" attributes. We are not using oneof because oneofs
+ // cannot contain repeated fields; it would require an extra set of messages.
+
+ // Version number.
+ //
+ // In version 0, if the "repeated xxx" representations contain only one
+ // element, that element is repeated to fill the shape. This makes it easy
+ // to represent a constant Tensor with a single value.
+ int32 version_number = 3;
+
+ // Serialized content from TensorBase::Serialize(). This representation can
+ // be used for all tensor types.
+ bytes tensor_content = 4;
+
+ // Type specific representations that make it easy to create tensor protos in
+ // all languages. Only the representation corresponding to "dtype" can
+ // be set. The values hold the flattened representation of the tensor in
+ // row major order.
+
+ // DT_FLOAT.
+ repeated float float_val = 5 [packed = true];
+
+ // DT_DOUBLE.
+ repeated double double_val = 6 [packed = true];
+
+ // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
+ repeated int32 int_val = 7 [packed = true];
+
+ // DT_STRING
+ repeated bytes string_val = 8;
+
+ // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
+ // and imaginary parts of i-th single precision complex.
+ repeated float scomplex_val = 9 [packed = true];
+
+ // DT_INT64
+ repeated int64 int64_val = 10 [packed = true];
+
+ // DT_BOOL
+ repeated bool bool_val = 11 [packed = true];
+};
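A short sketch of building a TensorProto by hand, following the field comments above; the single float_val entry relies on the version-0 "repeat to fill the shape" convention, and the helper name is illustrative:

```cpp
// Illustrative sketch (not part of this change).
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {

void ConstantTensorProtoExample() {
  // A 2 x 3 float tensor; the single float_val entry is repeated to fill
  // the shape under the version-0 convention described above.
  const char* ptxt = R"PROTO(
    dtype: DT_FLOAT
    tensor_shape { dim { size: 2 } dim { size: 3 } }
    float_val: 0.5
  )PROTO";
  TensorProto proto;
  CHECK(protobuf::TextFormat::ParseFromString(ptxt, &proto));
}

}  // namespace tensorflow
```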
diff --git a/tensorflow/core/framework/tensor_description.proto b/tensorflow/core/framework/tensor_description.proto
new file mode 100644
index 0000000000..1fff3ee155
--- /dev/null
+++ b/tensorflow/core/framework/tensor_description.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/allocation_description.proto";
+
+message TensorDescription {
+ // Data type of tensor elements
+ DataType dtype = 1;
+
+ // Shape of the tensor.
+ TensorShapeProto shape = 2;
+
+ // Information about the size and allocator used for the data
+ AllocationDescription allocation_description = 4;
+};
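A brief sketch of how this message gets filled, using Tensor::FillDescription() from tensor.cc above; the helper name is illustrative:

```cpp
// Illustrative sketch (not part of this change).
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {

void DescriptionExample() {
  Tensor t(DT_FLOAT, TensorShape({2, 3}));
  TensorDescription desc;
  t.FillDescription(&desc);
  // desc now carries the dtype, the shape, and whatever allocation
  // information the tensor's buffer reports.
  LOG(INFO) << desc.ShortDebugString();
}

}  // namespace tensorflow
```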
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
new file mode 100644
index 0000000000..3db2ffaaca
--- /dev/null
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -0,0 +1,138 @@
+#include "tensorflow/core/public/tensor_shape.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// An upper limit of the total number of elements in a tensor.
+static const int64 kMaxElements = (1LL << 40);
+
+bool TensorShape::IsValid(const TensorShapeProto& proto) {
+ int64 num_elements = 1;
+ for (const auto& d : proto.dim()) {
+ if (d.size() < 0) return false;
+ num_elements *= d.size();
+ if (num_elements > kMaxElements) return false;
+ }
+ return true;
+}
+
+TensorShape::TensorShape(const TensorShapeProto& proto) {
+ dim_sizes_.reserve(proto.dim_size());
+ num_elements_ = 1;
+ for (const auto& d : proto.dim()) {
+ AddDim(d.size());
+ }
+}
+
+TensorShape::TensorShape(gtl::ArraySlice<int64> dim_sizes) {
+ dim_sizes_.reserve(dim_sizes.size());
+ num_elements_ = 1;
+ for (auto s : dim_sizes) {
+ AddDim(s);
+ }
+}
+
+TensorShape::TensorShape() : num_elements_(1) {}
+
+void TensorShape::Clear() {
+ dim_sizes_.clear();
+ num_elements_ = 1;
+}
+
+void TensorShape::AddDim(int64 size) {
+ CHECK_GE(size, 0);
+ dim_sizes_.push_back(size);
+ num_elements_ *= size;
+ CHECK_LE(0, num_elements_);
+ CHECK_LE(num_elements_, kMaxElements);
+}
+
+void TensorShape::AppendShape(const TensorShape& shape) {
+ for (auto d : shape) AddDim(d.size);
+}
+
+void TensorShape::InsertDim(int d, int64 size) {
+ CHECK_GE(d, 0);
+ CHECK_LE(d, dims());
+ CHECK_GE(size, 0);
+ dim_sizes_.insert(dim_sizes_.begin() + d, size);
+ num_elements_ *= size;
+ CHECK_LE(0, num_elements_);
+ CHECK_LE(num_elements_, kMaxElements);
+}
+
+void TensorShape::set_dim(int d, int64 size) {
+ CHECK_GE(d, 0);
+ CHECK_LT(d, dims());
+ CHECK_GE(size, 0);
+
+ // Update the number of elements. num_elements_ is int64.
+ dim_sizes_[d] = size;
+ recompute_dims();
+}
+
+void TensorShape::RemoveDim(int d) {
+ CHECK_GE(d, 0);
+ CHECK_LT(d, dims());
+
+ // Update the number of elements and remove the dimension from the
+ // sizes.
+ dim_sizes_.erase(dim_sizes_.begin() + d);
+ recompute_dims();
+}
+
+void TensorShape::recompute_dims() {
+ num_elements_ = 1;
+ for (auto s : dim_sizes_) {
+ num_elements_ *= s;
+ CHECK_LE(0, num_elements_);
+ CHECK_LE(num_elements_, kMaxElements);
+ }
+}
+
+bool TensorShape::IsSameSize(const TensorShape& b) const {
+ if (b.dims() != dims()) return false;
+ for (int d = 0; d < dims(); d++) {
+ if (dim_size(d) != b.dim_size(d)) return false;
+ }
+ return true;
+}
+
+void TensorShape::AsProto(TensorShapeProto* proto) const {
+ proto->Clear();
+ for (size_t d = 0; d < dim_sizes_.size(); ++d) {
+ auto* dim = proto->add_dim();
+ dim->set_size(dim_sizes_[d]);
+ }
+}
+
+TensorShapeIter TensorShape::begin() const { return TensorShapeIter(this, 0); }
+
+TensorShapeIter TensorShape::end() const {
+ return TensorShapeIter(this, dims());
+}
+
+string TensorShape::DebugString() const {
+ TensorShapeProto proto;
+ AsProto(&proto);
+ return proto.ShortDebugString();
+}
+
+string TensorShape::ShortDebugString() const {
+ return strings::StrCat(
+ "[", str_util::Join(gtl::ArraySlice<int64>(dim_sizes_), ","), "]");
+}
+
+bool TensorShapeUtils::StartsWith(const TensorShape& shape,
+ const TensorShape& prefix) {
+ if (shape.dims() < prefix.dims()) return false;
+ for (int i = 0; i < prefix.dims(); i++) {
+ if (shape.dim_size(i) != prefix.dim_size(i)) return false;
+ }
+ return true;
+}
+
+} // namespace tensorflow
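A brief sketch of the shape-manipulation methods implemented above (the helper name is illustrative):

```cpp
// Illustrative sketch (not part of this change).
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

void ShapeExample() {
  TensorShape s({10, 5});  // 2-D shape with 50 elements.
  s.AddDim(20);            // {10, 5, 20}, 1000 elements.
  s.set_dim(1, 2);         // {10, 2, 20}, 400 elements.
  s.RemoveDim(0);          // {2, 20}, 40 elements.
  s.InsertDim(0, 3);       // {3, 2, 20}, 120 elements.
  CHECK_EQ(3, s.dims());
  CHECK_EQ(120, s.num_elements());
  LOG(INFO) << s.ShortDebugString();  // Prints "[3,2,20]".
}

}  // namespace tensorflow
```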
diff --git a/tensorflow/core/framework/tensor_shape.proto b/tensorflow/core/framework/tensor_shape.proto
new file mode 100644
index 0000000000..8fe7cce13d
--- /dev/null
+++ b/tensorflow/core/framework/tensor_shape.proto
@@ -0,0 +1,29 @@
+// Protocol buffer representing the shape of tensors.
+
+syntax = "proto3";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+// Dimensions of a tensor.
+message TensorShapeProto {
+ // One dimension of the tensor.
+ message Dim {
+ // Size of the tensor in that dimension.
+ int64 size = 1;
+
+ // Optional name of the tensor dimension.
+ string name = 2;
+ };
+
+ // Dimensions of the tensor, such as {"input", 30}, {"output", 40} for a 30 x
+ // 40 2D tensor. The names are optional.
+ //
+ // The order of entries in "dim" matters: It indicates the layout of the
+ // values in the tensor in-memory representation.
+ //
+ // The first entry in "dim" is the outermost dimension used to layout the
+ // values, the last entry is the innermost dimension. This matches the
+ // in-memory layout of RowMajor Eigen tensors.
+ repeated Dim dim = 2;
+};
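A small sketch of populating a TensorShapeProto directly, with the outermost dimension first as described above (the helper name is illustrative):

```cpp
// Illustrative sketch (not part of this change).
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

void ShapeProtoExample() {
  TensorShapeProto proto;
  TensorShapeProto::Dim* d0 = proto.add_dim();  // The first entry is the
  d0->set_size(30);                             // outermost dimension.
  d0->set_name("input");                        // Names are optional.
  proto.add_dim()->set_size(40);

  CHECK(TensorShape::IsValid(proto));
  TensorShape shape(proto);  // A 30 x 40 shape.
  CHECK_EQ(1200, shape.num_elements());
}

}  // namespace tensorflow
```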
diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc
new file mode 100644
index 0000000000..adac1a4787
--- /dev/null
+++ b/tensorflow/core/framework/tensor_shape_test.cc
@@ -0,0 +1,75 @@
+#include "tensorflow/core/public/tensor_shape.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(TensorShapeTest, Default) {
+ // The default TensorShape constructor constructs a shape of 0-dim
+ // and 1-element.
+ TensorShape s;
+ EXPECT_EQ(s.dims(), 0);
+ EXPECT_EQ(s.num_elements(), 1);
+}
+
+TEST(TensorShapeTest, set_dim) {
+ TensorShape s({10, 5});
+
+ s.set_dim(0, 20);
+ ASSERT_EQ(2, s.dims());
+ EXPECT_EQ(20, s.dim_size(0));
+ EXPECT_EQ(100, s.num_elements());
+
+ s.set_dim(1, 2);
+ ASSERT_EQ(2, s.dims());
+ EXPECT_EQ(2, s.dim_size(1));
+ EXPECT_EQ(40, s.num_elements());
+}
+
+TEST(TensorShapeTest, RemoveDim) {
+ TensorShape s({10, 5});
+ s.RemoveDim(0);
+ EXPECT_EQ(5, s.num_elements());
+ ASSERT_EQ(1, s.dims());
+}
+
+TEST(TensorShapeTest, RemoveAndAddDim) {
+ TensorShape s({10, 5, 20});
+ s.RemoveDim(1);
+ s.AddDim(100);
+
+ EXPECT_EQ(20000, s.num_elements());
+ ASSERT_EQ(3, s.dims());
+}
+
+TEST(TensorShapeTest, InvalidShapeProto) {
+ TensorShapeProto proto;
+ EXPECT_TRUE(TensorShape::IsValid(proto));
+
+ proto.add_dim()->set_size(357);
+ proto.add_dim()->set_size(982);
+ EXPECT_TRUE(TensorShape::IsValid(proto));
+
+ proto.Clear();
+ proto.add_dim()->set_size(-357);
+ proto.add_dim()->set_size(-982);
+ EXPECT_FALSE(TensorShape::IsValid(proto));
+
+ proto.Clear();
+ proto.add_dim()->set_size(1LL << 20);
+ proto.add_dim()->set_size((1LL << 20) + 1);
+ EXPECT_FALSE(TensorShape::IsValid(proto));
+}
+
+TEST(TensorShapeTest, SetDimForEmptyTensor) {
+ TensorShape s({10, 5, 20});
+ EXPECT_EQ(1000, s.num_elements());
+ s.set_dim(1, 0);
+ EXPECT_EQ(0, s.num_elements());
+ s.set_dim(1, 7);
+ EXPECT_EQ(1400, s.num_elements());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc
new file mode 100644
index 0000000000..473d9463ee
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.cc
@@ -0,0 +1,226 @@
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
+
+TensorSlice::TensorSlice(const TensorSliceProto& proto) {
+ starts_.reserve(proto.extent_size());
+ lengths_.reserve(proto.extent_size());
+ for (const auto& e : proto.extent()) {
+ starts_.push_back(e.start());
+ lengths_.push_back(GetExtentLength(e));
+ }
+}
+
+TensorSlice::TensorSlice(std::initializer_list<std::pair<int, int>> extents) {
+ starts_.reserve(extents.size());
+ lengths_.reserve(extents.size());
+ for (const auto& e : extents) {
+ starts_.push_back(e.first);
+ lengths_.push_back(e.second);
+ }
+}
+
+Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
+ std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
+ slice->starts_.reserve(items.size());
+ slice->lengths_.reserve(items.size());
+ for (const string& x : items) {
+ int s, l;
+ if (x == "-") {
+ // "everything"
+ s = 0;
+ l = kFullExtent;
+ } else {
+ char junk;
+ if (sscanf(x.c_str(), "%d,%d%c", &s, &l, &junk) != 2) {
+ return errors::InvalidArgument(
+ "Expected a pair of numbers or '-' "
+ "but got '",
+ x, "': string = ", str);
+ }
+ if (s < 0 || l <= 0) {
+ return errors::InvalidArgument(
+ "Expected non-negative start and "
+ "positive length but got start = ",
+ s, ", length = ", l, ": string = ", str);
+ }
+ }
+ slice->starts_.push_back(s);
+ slice->lengths_.push_back(l);
+ }
+
+ return Status::OK();
+}
+
+void TensorSlice::Clear() {
+ starts_.clear();
+ lengths_.clear();
+}
+
+void TensorSlice::SetFullSlice(int dim) {
+ Clear();
+ starts_.reserve(dim);
+ lengths_.reserve(dim);
+ for (int d = 0; d < dim; ++d) {
+ starts_.push_back(0);
+ lengths_.push_back(kFullExtent);
+ }
+}
+
+void TensorSlice::Extend(int dim) {
+ int old_dim = dims();
+ DCHECK_LE(old_dim, dim);
+ starts_.resize(dim);
+ lengths_.resize(dim);
+ for (int d = old_dim; d < dim; ++d) {
+ starts_[d] = 0;
+ lengths_[d] = kFullExtent;
+ }
+}
+
+void TensorSlice::AsProto(TensorSliceProto* proto) const {
+ for (int d = 0; d < dims(); ++d) {
+ TensorSliceProto::Extent* e = proto->add_extent();
+ // We only need to record the explicit slice for non-full slices
+ if (!IsFullAt(d)) {
+ e->set_start(starts_[d]);
+ e->set_length(lengths_[d]);
+ }
+ }
+}
+
+string TensorSlice::DebugString() const {
+ string buffer;
+ bool first = true;
+ for (int d = 0; d < dims(); ++d) {
+ if (!first) {
+ buffer.append(":");
+ }
+ string s;
+ if (IsFullAt(d)) {
+ buffer.append("-");
+ } else {
+ strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
+ }
+ first = false;
+ }
+ return buffer;
+}
+
+bool TensorSlice::Intersect(const TensorSlice& other,
+ TensorSlice* result) const {
+ // First, if two slices have different ranks, they obviously don't overlap
+ // -- in fact they are not compatible.
+ if (dims() != other.dims()) {
+ return false;
+ }
+
+ // Setting the result to the right dimension
+ if (result) {
+ result->SetFullSlice(dims());
+ }
+ // The two slices overlap if they overlap in all dimensions.
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ if (result) {
+ result->set_start(d, other.start(d));
+ result->set_length(d, other.length(d));
+ }
+ } else if (other.IsFullAt(d)) {
+ if (result) {
+ result->set_start(d, start(d));
+ result->set_length(d, length(d));
+ }
+ } else {
+ // If we have an intersection here, it should have a start that is the
+ // max of the two starts and an end that is the min of the two ends.
+ int s = std::max(start(d), other.start(d));
+ int l = std::min(end(d), other.end(d)) - s;
+ if (l > 0) {
+ // We have a real intersection
+ if (result) {
+ result->set_start(d, s);
+ result->set_length(d, l);
+ }
+ } else {
+ // We don't have an intersection for this dimension -- thus we don't
+ // have any intersection at all.
+ if (result) {
+ result->Clear();
+ }
+ return false;
+ }
+ }
+ }
+ // If we are here, we know there is overlap in every dimension.
+ return true;
+}
+
+void TensorSlice::ComputeRelative(const TensorSlice& sub,
+ TensorSlice* relative) const {
+ DCHECK_EQ(dims(), sub.dims());
+ relative->SetFullSlice(dims());
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ relative->set_start(d, sub.start(d));
+ relative->set_length(d, sub.length(d));
+ } else {
+ // Otherwise the relative start is the difference between the start of
+ // sub and the start of base
+ relative->set_start(d, sub.start(d) - start(d));
+ relative->set_length(d, sub.length(d));
+ }
+ }
+}
+
+// static
+bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
+ return extent.has_length_case() == TensorSliceProto::Extent::kLength;
+}
+
+// static
+int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
+ if (!HasExtentLength(extent)) return -1;
+ return extent.length();
+}
+
+Status TensorSlice::SliceTensorShape(const TensorShape& shape,
+ TensorShape* result_shape) const {
+ result_shape->Clear();
+ // Mismatching ranks: we can't apply the slice at all.
+ if (shape.dims() != dims()) {
+ return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
+ ", slice = ", DebugString());
+ }
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ result_shape->AddDim(shape.dim_size(d));
+ } else {
+ // Check if the extent applies to the dimension
+ if (end(d) <= shape.dim_size(d)) {
+ // Yes: the end is within the range of the dim -- we adjust the result
+ // shape so that its size along this dimension is the length of the
+ // slice.
+ result_shape->AddDim(length(d));
+ } else {
+ // The extent doesn't apply to the dimension
+ result_shape->Clear();
+ return errors::Internal("Extent in dimension ", d,
+ " out of bounds: shape = ", shape.DebugString(),
+ ", slice = ", DebugString());
+ }
+ }
+ }
+ // If we are here, we have successfully applied the shape.
+ return Status::OK();
+}
+
+const int TensorSlice::kFullExtent = -1;
+
+} // namespace tensorflow
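A compact sketch of the slice operations implemented above, assuming the string format and semantics documented in tensor_slice.h (the helper name is illustrative):

```cpp
// Illustrative sketch (not part of this change).
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/tensor_shape.h"

namespace tensorflow {

void SliceExample() {
  // "-" denotes the full extent along a dimension.
  TensorSlice a = TensorSlice::ParseOrDie("0,2:-");
  TensorSlice b = TensorSlice::ParseOrDie("1,3:0,4");

  // Per dimension: start = max of the starts, end = min of the ends.
  TensorSlice c;
  CHECK(a.Intersect(b, &c));
  LOG(INFO) << c.DebugString();  // "1,1:0,4"

  // Applying a slice to a shape keeps full dimensions and shrinks the rest.
  TensorShape shape({5, 4});
  TensorShape sliced;
  Status st = a.SliceTensorShape(shape, &sliced);
  CHECK(st.ok());
  CHECK_EQ(2, sliced.dim_size(0));
  CHECK_EQ(4, sliced.dim_size(1));
}

}  // namespace tensorflow
```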
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
new file mode 100644
index 0000000000..8e2f108c3f
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+
+#include <string>
+#include "tensorflow/core/framework/tensor_slice.pb.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A tensor slice represents a slice of a given tensor. It is represented by a
+// list of (start, length) pairs, where the size of the list is the rank of the
+// tensor.
+
+class TensorSlice {
+ public:
+ // Construct a tensor slice. There are several ways:
+ // -- creating an empty slice
+ // -- from just a dimension (in this case it will create a full slice)
+ // -- from an array of pairs of integers.
+ // -- from a TensorSliceProto protocol buffer
+ // -- from a string of the form "start,length:start,length..." where each
+ // "start,length" pair represents the slice on one dimension. We allow a
+ // special "-" that means "everything for this dimension". One such example
+ // is: 0,10:-:14,1:-:-
+ TensorSlice() {}
+ explicit TensorSlice(int dim);
+ explicit TensorSlice(const TensorSliceProto& proto);
+ explicit TensorSlice(std::initializer_list<std::pair<int, int>> extents);
+
+ static Status Parse(const string& str, TensorSlice* output);
+ static TensorSlice ParseOrDie(const string& str) {
+ TensorSlice ret;
+ Status s = Parse(str, &ret);
+ if (!s.ok()) {
+ LOG(FATAL) << "Could not parse TensorSlice";
+ }
+ return ret;
+ }
+
+ void Clear();
+
+ // Accessors
+ int dims() const { return starts_.size(); }
+
+ int start(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return starts_[d];
+ }
+
+ int length(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return lengths_[d];
+ }
+
+ int end(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return start(d) + length(d);
+ }
+
+ void set_start(int d, int x) {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ DCHECK_GE(x, 0);
+ starts_[d] = x;
+ }
+
+ void set_length(int d, int x) {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ lengths_[d] = x;
+ }
+
+ // Returns true if we have a full slice along dimension "d".
+ bool IsFullAt(int d) const { return lengths_[d] < 0; }
+
+ // Set the slice to be a full slice of "dim" dimensions
+ void SetFullSlice(int dim);
+
+ // Extend a slice to "dim" dimensions: all the added dimensions are full.
+ // Requires: dim >= dims().
+ void Extend(int dim);
+
+ // Conversion of a TensorSlice to other formats
+ void AsProto(TensorSliceProto* proto) const;
+ string DebugString() const;
+
+ // Fill *indices and *sizes from *this (so that we can use the slice()
+ // function in eigen tensor). We need a tensor shape in case some of the
+ // slices are full slices.
+ // We allow NDIMS to be greater than dims(), in which case we will pad the
+ // higher dimensions with trivial dimensions.
+ template <int NDIMS>
+ void FillIndicesAndSizes(const TensorShape& shape,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
+
+ // Interaction with other TensorSlices.
+
+ // Compute the intersection with another slice and if "result" is not
+ // nullptr, store the result in *result; returns true if there is any real
+ // intersection.
+ bool Intersect(const TensorSlice& other, TensorSlice* result) const;
+ // A shorthand.
+ bool Overlaps(const TensorSlice& other) const {
+ return Intersect(other, nullptr);
+ }
+
+ // Interaction with TensorShape.
+
+ // Slices a shape and stores the result into *result_shape.
+ // Requires that the shape and *this have the same rank.
+ // For example, given a tensor shape of {3, 4, 5}, and a slice of
+ // 1,2:-:0,2, the result shape is {2, 4, 2}.
+ Status SliceTensorShape(const TensorShape& shape,
+ TensorShape* result_shape) const;
+
+ // Given slice "sub" where "sub" is fully contained in *this,
+ // (meaning that the intersection of "sub" and *this equals "sub"), computes
+ // the "relative" slice of "sub" with respect to *this.
+ //
+ // In other words, if we use A>S to denote slicing a shape S with a slice A,
+ // then the function is computing a slice X such that:
+ // X > (this > S) = sub > S
+ // for any shape S.
+ //
+ // In general, along every dimension, the start of the relative slice is the
+ // start of the "sub" slice minus the start of *this; the length of the
+ // relative slice is the length of the "sub" slice.
+ //
+ // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and
+ // "sub" is 1,1:2,2:1,2, then the relative slice is 1,1:2,2:0,2.
+ //
+ // The caller needs to make sure that "sub" is indeed a sub-slice of *this;
+ // otherwise the result is undefined.
+ void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const;
+
+ // Returns true if the length field was specified in an Extent.
+ static bool HasExtentLength(const TensorSliceProto::Extent& extent);
+
+ // Returns the value of the length field in an Extent, or -1 if it
+ // is not present.
+ static int64 GetExtentLength(const TensorSliceProto::Extent& extent);
+
+ private:
+ // A length value of kFullExtent (-1) means we have a full slice at this
+ // dimension. It's defined in tensor_slice.cc.
+ static const int kFullExtent;
+
+ // TODO(yangke): switch to Eigen once it supports variable size arrays.
+ // A value of kFullExtent in lengths_[d] marks dimension d as a full slice.
+ gtl::InlinedVector<int, 4> starts_;
+ gtl::InlinedVector<int, 4> lengths_;
+};
+
+template <int NDIMS>
+void TensorSlice::FillIndicesAndSizes(
+ const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+ CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
+ << "slices: shape = " << shape.DebugString()
+ << ", slice = " << DebugString();
+ CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from "
+ << "a slice of dimension " << dims();
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = shape.dim_size(d);
+ } else {
+ (*indices)[d] = starts_[d];
+ (*sizes)[d] = lengths_[d];
+ }
+ }
+ for (int d = dims(); d < NDIMS; ++d) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = 1;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
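A sketch of feeding FillIndicesAndSizes() into Eigen's slice(), as the method comment suggests; it mirrors the slicing pattern used in tensor_test.cc, and the helper name is illustrative:

```cpp
// Illustrative sketch (not part of this change).
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/public/tensor.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

void EigenSliceExample() {
  Tensor t(DT_FLOAT, TensorShape({4, 6}));
  t.flat<float>().setConstant(0.0f);

  // Rows 1..2, all columns.
  TensorSlice slice = TensorSlice::ParseOrDie("1,2:-");
  Eigen::DSizes<ptrdiff_t, 2> indices;
  Eigen::DSizes<ptrdiff_t, 2> sizes;
  slice.FillIndicesAndSizes<2>(t.shape(), &indices, &sizes);
  // indices == {1, 0}, sizes == {2, 6}.

  // Mark the selected block with ones via Eigen's slice().
  auto m = t.tensor<float, 2>();
  Eigen::Tensor<float, 2, Eigen::RowMajor> ones(sizes[0], sizes[1]);
  ones.setConstant(1.0f);
  m.slice(indices, sizes) = ones;
}

}  // namespace tensorflow
```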
diff --git a/tensorflow/core/framework/tensor_slice.proto b/tensorflow/core/framework/tensor_slice.proto
new file mode 100644
index 0000000000..ca676bc766
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.proto
@@ -0,0 +1,34 @@
+// Protocol buffer representing slices of a tensor
+
+syntax = "proto3";
+// option cc_enable_arenas = true;
+
+package tensorflow;
+
+// Can only be interpreted if you know the corresponding TensorShape.
+message TensorSliceProto {
+ // Extent of the slice in one dimension.
+ message Extent {
+ // Either both or neither of the attributes must be set. When neither is
+ // set, the extent covers all data in that dimension.
+
+ // Start index of the slice, starting at 0.
+ int64 start = 1;
+
+ // Length of the slice: if the length is missing or -1 we will
+ // interpret this as "everything in this dimension". We use
+ // "oneof" to preserve information about whether the length is
+ // present without changing the serialization format from the
+ // prior proto2 version of this proto.
+ oneof has_length {
+ int64 length = 2;
+ }
+ };
+
+ // Extent of the slice in all tensor dimensions.
+ //
+ // Must have one entry for each dimension of the tensor that this
+ // slice belongs to. The order of the extents is the same as the order of
+ // dimensions in the TensorShape.
+ repeated Extent extent = 1;
+};
diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc
new file mode 100644
index 0000000000..5f718a56b6
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice_test.cc
@@ -0,0 +1,246 @@
+#include "tensorflow/core/framework/tensor_slice.h"
+
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+// Basic tests
+TEST(TensorSliceTest, Basic) {
+ {
+ // Repeatedly setting FullSlice should work.
+ TensorSlice s(3);
+ EXPECT_EQ("-:-:-", s.DebugString());
+
+ s.SetFullSlice(4);
+ EXPECT_EQ("-:-:-:-", s.DebugString());
+ }
+}
+
+// Testing for serialization and parsing for the string format of slices.
+TEST(TensorSliceTest, Serialization) {
+ // Serialization
+ {
+ TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}});
+ EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
+ }
+
+ {
+ TensorSliceProto proto;
+ // Define ptxt outside ASSERT_TRUE call to work around bug in some
+ // versions of gcc that breaks when you use raw string literals
+ // inside macro expansions.
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
+ const char* ptxt = R"PROTO(
+ extent { }
+ extent { start: 0 length: 10 }
+ extent { start: 14 length: 1 }
+ extent { }
+ )PROTO";
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
+ TensorSlice s(proto);
+ EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
+ }
+
+ // Parsing
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5");
+ TensorSliceProto proto;
+ s.AsProto(&proto);
+ EXPECT_EQ(
+ "extent { } "
+ "extent { } "
+ "extent { start: 1 length: 3 } "
+ "extent { start: 4 length: 5 }",
+ proto.ShortDebugString());
+ }
+
+ // Failed parsing
+ {
+ TensorSlice slice;
+ Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice);
+ EXPECT_EQ(
+ "Invalid argument: "
+ "Expected a pair of numbers or '-' but got '4': "
+ "string = -:-:1,3:4:5",
+ s.ToString());
+ }
+ {
+ TensorSlice slice;
+ Status s = TensorSlice::Parse("-:-1,3", &slice);
+ EXPECT_EQ(
+ "Invalid argument: "
+ "Expected non-negative start and positive length but got "
+ "start = -1, length = 3: string = -:-1,3",
+ s.ToString());
+ }
+}
+
+// Testing the slice intersection
+TEST(TensorSliceTest, Intersection) {
+ // "EVERYTHING" intersects with everything
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:-");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("1,2:3,4", c.DebugString());
+ }
+
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:-");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
+ TensorSlice c;
+ EXPECT_TRUE(b.Intersect(a, &c));
+ EXPECT_EQ("1,2:3,4", c.DebugString());
+ }
+
+ // Overlap at all dimensions
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10");
+ TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString());
+ }
+
+ // A mixture of everything and non-trivial slices
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("-:1,1");
+ TensorSlice b = TensorSlice::ParseOrDie("-:0,2");
+ TensorSlice c;
+ EXPECT_TRUE(a.Intersect(b, &c));
+ EXPECT_EQ("-:1,1", c.DebugString());
+ }
+
+ // No overlap on dimension 3: "3,1" and "4,5" don't intersect
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6");
+ TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6");
+ TensorSlice c;
+ EXPECT_FALSE(a.Intersect(b, &c));
+ EXPECT_EQ("", c.DebugString());
+ }
+ // No intersection when there are different dimensions
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-");
+ TensorSlice b = TensorSlice::ParseOrDie("-:-");
+ TensorSlice c;
+ EXPECT_FALSE(a.Intersect(b, &c));
+ EXPECT_EQ("", c.DebugString());
+ }
+}
+
+// Testing applying a slice to a tensor shape
+TEST(TensorSliceTest, SliceTensorShape) {
+ // A proper application
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6");
+ TensorShape x({2, 4, 5, 8});
+ TensorShape y;
+ EXPECT_OK(a.SliceTensorShape(x, &y));
+ EXPECT_EQ(
+ "dim { size: 1 } "
+ "dim { size: 4 } "
+ "dim { size: 1 } "
+ "dim { size: 6 }",
+ y.DebugString());
+ }
+
+ // An invalid application -- dimension 2 is out of bound
+ {
+ TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-");
+ TensorShape x({2, 4, 5, 8});
+ TensorShape y;
+ EXPECT_EQ(
+ "Internal: "
+ "Extent in dimension 1 out of bounds: "
+ "shape = dim { size: 2 } "
+ "dim { size: 4 } "
+ "dim { size: 5 } "
+ "dim { size: 8 }, "
+ "slice = 1,1:1,4:-:-",
+ a.SliceTensorShape(x, &y).ToString());
+ EXPECT_EQ("", y.DebugString());
+ }
+}
+
+// Testing the computation of relative slices.
+TEST(TensorSliceTest, ComputeRelative) {
+ // Easy case: base is "everything"
+ {
+ TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-");
+ TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4");
+ TensorSlice relative;
+ base.ComputeRelative(sub, &relative);
+ EXPECT_EQ("-:1,2:-:3,4", relative.DebugString());
+ }
+
+ // A slightly more complicated case
+ {
+ TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1");
+ TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1");
+ TensorSlice relative;
+ base.ComputeRelative(sub, &relative);
+ EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString());
+ }
+}
+
+TEST(TensorSliceTest, ExtentLength) {
+ TensorSliceProto proto;
+ // Define ptxt outside ASSERT_TRUE call to work around bug in some
+ // versions of gcc that breaks when you use raw string literals
+ // inside macro expansions.
+ // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
+ const char* ptxt = R"PROTO(
+ extent { }
+ extent { start: 0 length: 10 }
+ extent { start: 14 length: 1 }
+ extent { }
+ )PROTO";
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
+ EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0)));
+ EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1)));
+ EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2)));
+ EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3)));
+ EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0)));
+ EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1)));
+ EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2)));
+ EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3)));
+}
+
+TEST(TensorSliceTest, Deserialization) {
+ // Serialization of
+ // extent { length: 5 }
+ // extent { start: 0 length: 10 }
+ // extent { start: 14 length: 1 }
+ // extent { start: 1 }
+ // extent { }
+ // in proto2 and proto3:
+ const char pb2[] =
+ "\x0A\x02\x10\x05\x0A\x04\x08\x00"
+ "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00";
+ const char pb3[] =
+ "\x0A\x02\x10\x05\x0A\x02"
+ "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00";
+ // (The difference is that in the proto3 version, "start: 0" isn't included
+ // since 0 is start's default value.)
+
+ TensorSliceProto proto2;
+ ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1));
+ TensorSlice ts2(proto2);
+
+ TensorSliceProto proto3;
+ ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1));
+ TensorSlice ts3(proto3);
+
+ // Both serializations should be interpreted the same.
+ EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString());
+ EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
new file mode 100644
index 0000000000..4963c2c219
--- /dev/null
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -0,0 +1,551 @@
+#include "tensorflow/core/public/tensor.h"
+
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(TensorTest, Default) {
+ Tensor t;
+ EXPECT_EQ(t.dtype(), DT_FLOAT);
+ EXPECT_EQ(t.dims(), 1);
+ EXPECT_EQ(t.NumElements(), 0);
+}
+
+TEST(TensorTest, DataType_Traits) {
+ EXPECT_TRUE(std::is_trivial<float>::value);
+ EXPECT_TRUE(std::is_trivial<double>::value);
+ EXPECT_TRUE(std::is_trivial<int32>::value);
+ EXPECT_TRUE(std::is_trivial<uint8>::value);
+ EXPECT_TRUE(std::is_trivial<int16>::value);
+ EXPECT_TRUE(std::is_trivial<int8>::value);
+ EXPECT_TRUE(std::is_trivial<int64>::value);
+ EXPECT_TRUE(std::is_trivial<bool>::value);
+ EXPECT_FALSE(std::is_trivial<string>::value);
+
+ EXPECT_EQ(sizeof(bool), 1);
+
+ // Unfortunately, std::complex::complex() initializes to (0, 0).
+ EXPECT_FALSE(std::is_trivial<complex64>::value);
+ EXPECT_FALSE(std::is_trivial<std::complex<double>>::value);
+ EXPECT_TRUE(std::is_trivial<float[2]>::value);
+ struct MyComplex {
+ float re, im;
+ };
+ EXPECT_TRUE(std::is_trivial<MyComplex>::value);
+}
+
+template <typename T>
+void TestCopies(const Tensor& t) {
+ {
+ LOG(INFO) << "CopyFrom()";
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.CopyFrom(t, t.shape()));
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "operator=()";
+ Tensor t2(t.dtype());
+ t2 = t;
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "deep copy";
+ Tensor t2(t.dtype(), t.shape());
+ t2.flat<T>() = t.flat<T>();
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "AsProtoField()";
+ TensorProto proto;
+ t.AsProtoField(&proto);
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.FromProto(proto));
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+ {
+ LOG(INFO) << "AsProtoTensorContent()";
+ TensorProto proto;
+ t.AsProtoTensorContent(&proto);
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.FromProto(proto));
+ test::ExpectTensorEqual<T>(t, t2);
+ // Make another copy via tensor_content field.
+ *proto.mutable_tensor_content() = proto.tensor_content();
+ Tensor t3(t.dtype());
+ EXPECT_TRUE(t3.FromProto(proto));
+ test::ExpectTensorEqual<T>(t, t3);
+ }
+ {
+ LOG(INFO) << "AsTensor";
+ gtl::ArraySlice<T> values(t.flat<T>().data(), t.NumElements());
+ Tensor t2 = test::AsTensor(values, t.shape());
+ test::ExpectTensorEqual<T>(t, t2);
+ }
+}
+
+TEST(Tensor_Float, Simple) {
+ Tensor t(DT_FLOAT, TensorShape({10, 20}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 20})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<float>()(a, b) = static_cast<float>(a * b);
+ }
+ }
+ TestCopies<float>(t);
+}
+
+TEST(Tensor_QInt8, Simple) {
+ Tensor t(DT_QINT8, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<qint8>()(a, b) = qint8(a * b);
+ }
+ }
+ TestCopies<qint8>(t);
+}
+
+TEST(Tensor_QUInt8, Simple) {
+ Tensor t(DT_QUINT8, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<Eigen::QUInt8>()(a, b) = Eigen::QUInt8(a * b);
+ }
+ }
+ TestCopies<Eigen::QUInt8>(t);
+}
+
+TEST(Tensor_QInt32, Simple) {
+ Tensor t(DT_QINT32, TensorShape({2, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
+ for (int64 a = 0; a < t.shape().dim_size(0); a++) {
+ for (int64 b = 0; b < t.shape().dim_size(1); b++) {
+ t.matrix<qint32>()(a, b) = qint32(static_cast<int32>(a * b));
+ }
+ }
+ TestCopies<qint32>(t);
+}
+
+TEST(Tensor_Float, Reshape) {
+ Tensor t(DT_FLOAT, TensorShape({2, 3, 4, 5}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 3, 4, 5})));
+
+ {
+ auto tensor = t.tensor<float, 4>();
+ EXPECT_EQ(2, tensor.dimension(0));
+ EXPECT_EQ(3, tensor.dimension(1));
+ EXPECT_EQ(4, tensor.dimension(2));
+ EXPECT_EQ(5, tensor.dimension(3));
+
+ // Set first and last elements.
+ tensor(0, 0, 0, 0) = 0.01f;
+ tensor(1, 2, 3, 4) = 0.02f;
+ }
+ {
+ auto shaped = t.shaped<float, 1>({120});
+ EXPECT_EQ(120, shaped.dimension(0));
+ EXPECT_EQ(shaped(0), 0.01f);
+ EXPECT_EQ(shaped(119), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 2>({6, 20});
+ EXPECT_EQ(6, shaped.dimension(0));
+ EXPECT_EQ(20, shaped.dimension(1));
+ EXPECT_EQ(shaped(0, 0), 0.01f);
+ EXPECT_EQ(shaped(5, 19), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 3>({6, 4, 5});
+ EXPECT_EQ(6, shaped.dimension(0));
+ EXPECT_EQ(4, shaped.dimension(1));
+ EXPECT_EQ(5, shaped.dimension(2));
+ EXPECT_EQ(shaped(0, 0, 0), 0.01f);
+ EXPECT_EQ(shaped(5, 3, 4), 0.02f);
+ }
+ {
+ auto shaped = t.shaped<float, 4>({2, 3, 4, 5});
+ EXPECT_EQ(2, shaped.dimension(0));
+ EXPECT_EQ(3, shaped.dimension(1));
+ EXPECT_EQ(4, shaped.dimension(2));
+ EXPECT_EQ(5, shaped.dimension(3));
+
+ EXPECT_EQ(shaped(0, 0, 0, 0), 0.01f);
+ EXPECT_EQ(shaped(1, 2, 3, 4), 0.02f);
+ }
+ {
+ auto flat = t.flat<float>();
+ EXPECT_EQ(flat(0), 0.01f);
+ EXPECT_EQ(120, flat.dimension(0));
+ EXPECT_EQ(flat(0), 0.01f);
+ EXPECT_EQ(flat(119), 0.02f);
+ }
+ {
+ auto flat_inner_dims = t.flat_inner_dims<float>();
+ EXPECT_EQ(24, flat_inner_dims.dimension(0));
+ EXPECT_EQ(5, flat_inner_dims.dimension(1));
+ EXPECT_EQ(flat_inner_dims(0, 0), 0.01f);
+ EXPECT_EQ(flat_inner_dims(23, 4), 0.02f);
+ }
+}
+
+TEST(Tensor_Scalar, Basics) {
+ {
+ Tensor t(DT_FLOAT, TensorShape({}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<float>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.scalar<float>()() = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt());
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.vec<float>();
+ EXPECT_EQ(1, Tt.size());
+ t.vec<float>()(0) = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt(0));
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({1, 1, 1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<float>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.flat<float>()(0) = 123.45f;
+ EXPECT_FLOAT_EQ(123.45f, Tt());
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<string>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.scalar<string>()() = "foo";
+ EXPECT_EQ("foo", Tt());
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.vec<string>();
+ EXPECT_EQ(1, Tt.size());
+ t.flat<string>()(0) = "foo";
+ EXPECT_EQ("foo", Tt(0));
+ }
+ {
+ Tensor t(DT_STRING, TensorShape({1, 1, 1}));
+ EXPECT_EQ(1, t.NumElements());
+ auto Tt = t.scalar<string>();
+ EXPECT_EQ(1, Tt.size());
+ EXPECT_EQ(0, Tt.rank());
+ t.flat<string>()(0) = "bar";
+ EXPECT_EQ("bar", Tt());
+ }
+ {
+ Tensor t(DT_FLOAT, TensorShape({0, 1}));
+ EXPECT_EQ(0, t.NumElements());
+ auto Tt = t.flat<float>();
+ EXPECT_EQ(0, Tt.size());
+ auto Tm = t.matrix<float>();
+ EXPECT_EQ(0, Tm.size());
+ EXPECT_EQ(0, Tm.dimensions()[0]);
+ EXPECT_EQ(1, Tm.dimensions()[1]);
+ }
+}
+
+TEST(Tensor_Float, Reshape_And_Slice_Assignment) {
+ // A test to experiment with a way to assign to a subset of a tensor
+ Tensor t(DT_FLOAT, TensorShape({10, 4, 3, 2}));
+ EXPECT_TRUE(t.shape().IsSameSize(TensorShape({10, 4, 3, 2})));
+
+ // Get the N dimensional tensor (N==4 here)
+ auto e_t = t.tensor<float, 4>();
+ // Reshape to view it as a two-dimensional tensor
+ auto e_2d = t.shaped<float, 2>({10, 4 * 3 * 2});
+ for (int i = 0; i < 10; i++) {
+ // Assign a 1 x 4*3*2 matrix (really vector) to a slice of size
+ // 1 x 4*3*2 in e_t.
+ Eigen::Tensor<float, 2, Eigen::RowMajor> m(1, 4 * 3 * 2);
+ m.setConstant(i * 2.0);
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> indices(i, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, 4 * 3 * 2);
+ e_2d.slice(indices, sizes) = m;
+ }
+ for (int i = 0; i < 10; i++) {
+ for (int j = 0; j < 4; j++) {
+ for (int k = 0; k < 3; k++) {
+ for (int l = 0; l < 2; l++) {
+ EXPECT_EQ(e_t(i, j, k, l), i * 2.0f);
+ LOG(INFO) << i << "," << j << "," << k << "," << l
+ << " &e_t(i, j, k, l): " << &e_t(i, j, k, l) << " = "
+ << e_t(i, j, k, l);
+ }
+ }
+ }
+ }
+}
+
+TEST(Tensor_String, Simple) {
+ Tensor t = test::AsTensor<string>(
+ {"hello", "world", "machine", "learning", "new", "york"},
+ TensorShape({3, 2}));
+ auto s = t.shape();
+ ASSERT_EQ(s.dims(), 2);
+ ASSERT_EQ(s.dim_size(0), 3);
+ ASSERT_EQ(s.dim_size(1), 2);
+ auto m = t.matrix<string>();
+ EXPECT_EQ(t.TotalBytes(), 3 * 2 * sizeof(string) + 5 + 5 + 7 + 8 + 3 + 4);
+
+ EXPECT_EQ(m(0, 0), "hello");
+ EXPECT_EQ(m(0, 1), "world");
+ EXPECT_EQ(m(1, 0), "machine");
+ EXPECT_EQ(m(1, 1), "learning");
+ EXPECT_EQ(m(2, 0), "new");
+ EXPECT_EQ(m(2, 1), "york");
+
+ TestCopies<string>(t);
+}
+
+TEST(Tensor_Float, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<float>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<float>() = t1.flat<float>() * 2.0f;
+ Tensor t3 = test::AsTensor<float>({0, 2, 4, 6, 8, 10}, t1.shape());
+ test::ExpectTensorEqual<float>(t2, t3);
+}
+
+TEST(Tensor_Int32, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<int32>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<int32>() = t1.flat<int32>() * 2;
+ Tensor t3 = test::AsTensor<int32>({0, 2, 4, 6, 8, 10}, t1.shape());
+ test::ExpectTensorEqual<int32>(t2, t3);
+}
+
+TEST(Tensor_QInt8, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<qint8>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<qint8>() = t1.flat<qint8>() + qint8(-2);
+ Tensor t3 = test::AsTensor<qint8>({-2, -1, 0, 1, 2, 3}, {2, 3});
+ test::ExpectTensorEqual<qint8>(t2, t3);
+}
+
+TEST(Tensor_QUInt8, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<quint8>({0, 1, 2, 3, 4, 5}, {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<quint8>() = t1.flat<quint8>() + quint8(2);
+ Tensor t3 = test::AsTensor<quint8>({2, 3, 4, 5, 6, 7}, {2, 3});
+ test::ExpectTensorEqual<quint8>(t2, t3);
+}
+
+TEST(Tensor_Int64, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<int64>(
+ {0LL << 48, 1LL << 48, 2LL << 48, 3LL << 48, 4LL << 48, 5LL << 48},
+ {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<int64>() = t1.flat<int64>() * static_cast<int64>(2);
+ Tensor t3 = test::AsTensor<int64>(
+ {0LL << 48, 2LL << 48, 4LL << 48, 6LL << 48, 8LL << 48, 10LL << 48},
+ {2, 3});
+ test::ExpectTensorEqual<int64>(t2, t3);
+}
+
+TEST(Tensor_String, SimpleWithHelper) {
+ Tensor t1 = test::AsTensor<string>({"0", "1", "2", "3", "4", "5"}, {2, 3});
+ Tensor t2(DT_STRING, {2, 3});
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ t2.matrix<string>()(i, j) = strings::StrCat(i * 3 + j);
+ }
+ }
+
+ // Test with helper.
+ test::ExpectTensorEqual<string>(t1, t2);
+}
+
+TEST(Tensor_Bool, SimpleWithHelper) {
+ Tensor t1 =
+ test::AsTensor<bool>({false, true, false, true, false, true}, {2, 3});
+
+ Tensor t2(DT_BOOL, {2, 3});
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ t2.matrix<bool>()(i, j) = (((i + j) % 2) != 0);
+ }
+ }
+
+ // Test with helper.
+ test::ExpectTensorEqual<bool>(t1, t2);
+}
+
+TEST(Tensor_Complex, Simple) {
+ Tensor t(DT_COMPLEX64, {4, 5, 3, 7});
+ t.flat<complex64>().setRandom();
+ TestCopies<complex64>(t);
+}
+
+TEST(Tensor_Complex, SimpleWithHelper) {
+ {
+ Tensor t1 = test::AsTensor<complex64>({0,
+ {1, 1},
+ complex64(2),
+ complex64(3, 3),
+ complex64(0, 4),
+ complex64(2, 5)},
+ {2, 3});
+ Tensor t2(t1.dtype(), t1.shape());
+ t2.flat<complex64>() = t1.flat<complex64>() * complex64(0, 2);
+ Tensor t3 = test::AsTensor<complex64>(
+ {0, {-2, 2}, {0, 4}, {-6, 6}, {-8, 0}, {-10, 4}},
+ // shape
+ {2, 3});
+ test::ExpectTensorEqual<complex64>(t2, t3);
+ }
+
+ // Does some numeric operations for complex numbers.
+ {
+ const float PI = std::acos(-1);
+ const complex64 rotate_45 = std::polar(1.0f, PI / 4);
+
+ // x contains all the 8-th root of unity.
+ Tensor x(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ x.vec<complex64>()(i) = std::pow(rotate_45, i);
+ }
+
+ // Shift the roots by 45 degree.
+ Tensor y(DT_COMPLEX64, TensorShape({8}));
+ y.vec<complex64>() = x.vec<complex64>() * rotate_45;
+ Tensor y_expected(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ y_expected.vec<complex64>()(i) = std::pow(rotate_45, i + 1);
+ }
+ test::ExpectTensorNear<complex64>(y, y_expected, 1e-5);
+
+ // Raise roots to the power of 8.
+ Tensor z(DT_COMPLEX64, TensorShape({8}));
+ z.vec<complex64>() = x.vec<complex64>().pow(8);
+ Tensor z_expected(DT_COMPLEX64, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ z_expected.vec<complex64>()(i) = 1;
+ }
+ test::ExpectTensorNear<complex64>(z, z_expected, 1e-5);
+ }
+}
+
+// On the alignment.
+//
+// As of 2015/8, tensorflow::Tensor allocates its buffer with 32-byte
+ // alignment. The Tensor::tensor/flat/vec/matrix methods require that the
+ // buffer satisfies Eigen::Aligned (e.g., 16-byte aligned usually,
+ // and 32-byte for AVX). Tensor::Slice requires the caller to ensure
+ // its result is aligned if the caller intends to use those methods.
+ // In this test case, we simply make sure each slice is 32-byte aligned:
+ // sizeof(float) * 4 * 34 = 544, a multiple of 32.
+TEST(Tensor, Slice_Basic) {
+ Tensor saved;
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 34}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.Slice(i, i + 1).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple slice along dim0.
+ Tensor y = x.Slice(4, 8);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 4, 34})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 3>();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(ty(i, j, k), 4.0 + i);
+ EXPECT_EQ(&tx(4 + i, j, k), &ty(i, j, k));
+ }
+ }
+ }
+ // A simple slice equivalent to identity.
+ TestCopies<float>(y);
+ y = x.Slice(0, 10);
+ test::ExpectTensorEqual<float>(x, y);
+ EXPECT_EQ(x.flat<float>().data(), y.flat<float>().data());
+
+ // A slice of a slice.
+ auto z = x.Slice(4, 8).Slice(2, 3);
+ auto tz = z.tensor<float, 3>();
+ EXPECT_EQ(1, z.dim_size(0));
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(tz(0, j, k), 6.0);
+ }
+ }
+
+ // x and y will be out of scope. But 'saved' should be alive.
+ saved = z;
+ }
+ {
+ EXPECT_EQ(1, saved.dim_size(0));
+ auto tsaved = saved.tensor<float, 3>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 34; ++k) {
+ EXPECT_EQ(tsaved(0, j, k), 6.0);
+ }
+ }
+ }
+ { // Empty
+ Tensor x(DT_FLOAT, TensorShape({10, 0, 34}));
+ x.flat<float>().setRandom();
+ Tensor y = x.Slice(4, 8);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 0, 34})));
+ }
+
+ {
+ // Test unaligned access via a Slice.
+ Tensor x(DT_FLOAT, TensorShape({30}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned slice.
+ Tensor y = x.Slice(1, 13);
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
+static void BM_CreateAndDestroy(int iters) {
+ TensorShape shape({10, 20});
+ while (--iters) {
+ Tensor t(DT_FLOAT, shape);
+ }
+}
+BENCHMARK(BM_CreateAndDestroy);
+
+static void BM_Assign(int iters) {
+ Tensor a(DT_FLOAT, TensorShape({10, 20}));
+ Tensor b(DT_FLOAT, TensorShape({10, 20}));
+ bool a_to_b = true;
+ while (--iters) {
+ if (a_to_b) {
+ b = a;
+ } else {
+ a = b;
+ }
+ a_to_b = !a_to_b;
+ }
+}
+BENCHMARK(BM_Assign);
+
+// Ensure tensor_data() works on empty tensors
+TEST(Tensor, EmptyTensorData) {
+ Tensor empty;
+ EXPECT_EQ(empty.tensor_data().size(), 0);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc
new file mode 100644
index 0000000000..b6cd12a864
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.cc
@@ -0,0 +1,43 @@
+#include <cmath>
+#include "tensorflow/core/framework/tensor_testutil.h"
+
+namespace tensorflow {
+namespace test {
+
+template <typename T>
+bool IsClose(const T& x, const T& y, double atol, double rtol) {
+ return fabs(x - y) < atol + rtol * fabs(x);
+}
+
+template <typename T>
+void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
+ auto Tx = x.flat<T>();
+ auto Ty = y.flat<T>();
+ for (int i = 0; i < Tx.size(); ++i) {
+ if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
+ LOG(ERROR) << "x = " << x.DebugString();
+ LOG(ERROR) << "y = " << y.DebugString();
+ LOG(ERROR) << "atol = " << atol << " rtol = " << rtol
+ << " tol = " << atol + rtol * std::fabs(Tx(i));
+ EXPECT_TRUE(false) << i << "-th element is not close " << Tx(i) << " vs. "
+ << Ty(i);
+ }
+ }
+}
+
+void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
+ internal::AssertSameTypeDims(x, y);
+ switch (x.dtype()) {
+ case DT_FLOAT:
+ ExpectClose<float>(x, y, atol, rtol);
+ break;
+ case DT_DOUBLE:
+ ExpectClose<double>(x, y, atol, rtol);
+ break;
+ default:
+ LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype());
+ }
+}
+
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h
new file mode 100644
index 0000000000..53d6da0fb2
--- /dev/null
+++ b/tensorflow/core/framework/tensor_testutil.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace test {
+
+// Constructs a scalar tensor with 'val'.
+template <typename T>
+Tensor AsScalar(const T& val) {
+ Tensor ret(DataTypeToEnum<T>::value, {});
+ ret.scalar<T>()() = val;
+ return ret;
+}
+
+// Constructs a flat tensor with 'vals'.
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals) {
+ Tensor ret(DataTypeToEnum<T>::value, {static_cast<int64>(vals.size())});
+ std::copy_n(vals.data(), vals.size(), ret.flat<T>().data());
+ return ret;
+}
+
+// Constructs a tensor of "shape" with values "vals".
+template <typename T>
+Tensor AsTensor(gtl::ArraySlice<T> vals, const TensorShape& shape) {
+ Tensor ret;
+ CHECK(ret.CopyFrom(AsTensor(vals), shape));
+ return ret;
+}
+
+// Fills in '*tensor' with 'vals'. E.g.,
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillValues<float>(&x, {11, 21, 21, 22});
+template <typename T>
+void FillValues(Tensor* tensor, gtl::ArraySlice<T> vals) {
+ auto flat = tensor->flat<T>();
+ CHECK_EQ(flat.size(), vals.size());
+ if (flat.size() > 0) {
+ std::copy_n(vals.data(), vals.size(), flat.data());
+ }
+}
+
+ // Fills in '*tensor' with a sequence of values: val, val+1, val+2, ...
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillIota<float>(&x, 1.0);
+template <typename T>
+void FillIota(Tensor* tensor, const T& val) {
+ auto flat = tensor->flat<T>();
+ std::iota(flat.data(), flat.data() + flat.size(), val);
+}
+
+ // Fills in '*tensor' with a sequence of values: fn(0), fn(1), ...
+// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2}));
+// test::FillFn<float>(&x, [](int i)->float { return i*i; });
+template <typename T>
+void FillFn(Tensor* tensor, std::function<T(int)> fn) {
+ auto flat = tensor->flat<T>();
+ for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i);
+}
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+// identical values.
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y);
+
+// Expects "x" and "y" are tensors of the same type, same shape, and
+ // approximately equal values, each within "abs_err".
+ template <typename T>
+ void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err);
+
+// Expects "x" and "y" are tensors of the same type (float or double),
+// same shape and element-wise difference between x and y is no more
+// than atol + rtol * abs(x).
+void ExpectClose(const Tensor& x, const Tensor& y, double atol = 1e-6,
+ double rtol = 1e-6);
+
+// Implementation details.
+
+namespace internal {
+
+template <typename T>
+struct is_floating_point_type {
+ static const bool value = std::is_same<T, float>::value ||
+ std::is_same<T, double>::value ||
+ std::is_same<T, std::complex<float> >::value ||
+ std::is_same<T, std::complex<double> >::value;
+};
+
+template <typename T>
+static void ExpectEqual(const T& a, const T& b) {
+ EXPECT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<float>(const float& a, const float& b) {
+ EXPECT_FLOAT_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<double>(const double& a, const double& b) {
+ EXPECT_DOUBLE_EQ(a, b);
+}
+
+template <>
+void ExpectEqual<complex64>(const complex64& a, const complex64& b) {
+ EXPECT_FLOAT_EQ(a.real(), b.real()) << a << " vs. " << b;
+ EXPECT_FLOAT_EQ(a.imag(), b.imag()) << a << " vs. " << b;
+}
+
+inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), y.dtype());
+ ASSERT_TRUE(x.IsSameSize(y))
+ << "x.shape [" << x.shape().DebugString() << "] vs "
+ << "y.shape [ " << y.shape().DebugString() << "]";
+}
+
+template <typename T, bool is_fp = is_floating_point_type<T>::value>
+struct Expector;
+
+template <typename T>
+struct Expector<T, false> {
+ static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+ static void Equal(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ ExpectEqual(a(i), b(i));
+ }
+ }
+};
+
+ // Partial specialization for floating point (and complex) types.
+template <typename T>
+struct Expector<T, true> {
+ static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
+
+ static void Equal(const Tensor& x, const Tensor& y) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ ExpectEqual(a(i), b(i));
+ }
+ }
+
+ static void Near(const T& a, const T& b, const double abs_err) {
+ if (a != b) { // Takes care of inf.
+ EXPECT_LE(std::abs(a - b), abs_err) << "a = " << a << " b = " << b;
+ }
+ }
+
+ static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
+ ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
+ AssertSameTypeDims(x, y);
+ auto a = x.flat<T>();
+ auto b = y.flat<T>();
+ for (int i = 0; i < a.size(); ++i) {
+ Near(a(i), b(i), abs_err);
+ }
+ }
+};
+
+} // namespace internal
+
+template <typename T>
+void ExpectTensorEqual(const Tensor& x, const Tensor& y) {
+ internal::Expector<T>::Equal(x, y);
+}
+
+template <typename T>
+void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
+ static_assert(internal::is_floating_point_type<T>::value,
+ "T is not a floating point types.");
+ internal::Expector<T>::Near(x, y, abs_err);
+}
+
+} // namespace test
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_TESTUTIL_H_
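A sketch of how these helpers are meant to be combined inside a gtest case; the test name is illustrative:

```cpp
// Illustrative sketch (not part of this change).
#include <gtest/gtest.h>

#include "tensorflow/core/framework/tensor_testutil.h"

namespace tensorflow {

TEST(TensorTestUtilExample, Helpers) {
  // Build a tensor directly from a value list and a shape.
  Tensor a = test::AsTensor<float>({1, 2, 3, 4}, TensorShape({2, 2}));

  // Fill an existing tensor with the sequence 1, 2, 3, 4.
  Tensor b(DT_FLOAT, TensorShape({2, 2}));
  test::FillIota<float>(&b, 1.0f);

  test::ExpectTensorEqual<float>(a, b);
  test::ExpectTensorNear<float>(a, b, 1e-6);
  test::ExpectClose(a, b);
}

}  // namespace tensorflow
```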
diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h
new file mode 100644
index 0000000000..077d86d442
--- /dev/null
+++ b/tensorflow/core/framework/tensor_types.h
@@ -0,0 +1,92 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// Helper to define Tensor types given that the scalar is of type T.
+template <typename T, int NDIMS = 1>
+struct TTypes {
+ // Rank-<NDIMS> tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor>,
+ Eigen::Aligned> Tensor;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor>,
+ Eigen::Aligned> ConstTensor;
+
+ // Unaligned Rank-<NDIMS> tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor> >
+ UnalignedTensor;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, NDIMS, Eigen::RowMajor> >
+ UnalignedConstTensor;
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, int>,
+ Eigen::Aligned> Tensor32Bit;
+
+ // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor>,
+ Eigen::Aligned> Scalar;
+ typedef Eigen::TensorMap<
+ Eigen::TensorFixedSize<const T, Eigen::Sizes<>, Eigen::RowMajor>,
+ Eigen::Aligned> ConstScalar;
+
+ // Unaligned Scalar tensor of scalar type T.
+ typedef Eigen::TensorMap<Eigen::TensorFixedSize<
+ T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedScalar;
+ typedef Eigen::TensorMap<Eigen::TensorFixedSize<
+ const T, Eigen::Sizes<>, Eigen::RowMajor> > UnalignedConstScalar;
+
+ // Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
+ Flat;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> ConstFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>, Eigen::Aligned>
+ Vec;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> ConstVec;
+
+ // Unaligned Rank-1 tensor (vector) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
+ UnalignedConstFlat;
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor> > UnalignedVec;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor> >
+ UnalignedConstVec;
+
+ // Rank-2 tensor (matrix) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>
+ Matrix;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned> ConstMatrix;
+
+ // Unaligned Rank-2 tensor (matrix) of scalar type T.
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor> >
+ UnalignedMatrix;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor> >
+ UnalignedConstMatrix;
+};
+
+typedef typename TTypes<float, 1>::Tensor32Bit::Index Index32;
+
+template <typename DSizes>
+Eigen::DSizes<Index32, DSizes::count> To32BitDims(const DSizes& in) {
+ Eigen::DSizes<Index32, DSizes::count> out;
+ for (int i = 0; i < DSizes::count; ++i) {
+ out[i] = in[i];
+ }
+ return out;
+}
+
+template <typename TensorType>
+typename TTypes<typename TensorType::Scalar,
+ TensorType::NumIndices>::Tensor32Bit
+To32Bit(TensorType in) {
+ typedef typename TTypes<typename TensorType::Scalar,
+ TensorType::NumIndices>::Tensor32Bit RetType;
+ return RetType(in.data(), To32BitDims(in.dimensions()));
+}
+
+} // namespace tensorflow
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
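
For orientation, a small sketch of viewing a raw buffer through one of these aliases (illustrative only; the Unaligned map is used so the sketch carries no Eigen alignment assumption). To32Bit() re-maps the same data with 32-bit indices for code that prefers narrower index arithmetic.

#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

// Wrap a raw 2 x 3 row-major float buffer and sum its coefficients.
// Eigen expressions (e.g. mat.sum()) can also be applied to the map.
float SumViaMap(const float* data) {
  TTypes<float, 2>::UnalignedConstMatrix mat(data, 2, 3);
  float sum = 0.0f;
  for (int i = 0; i < mat.dimension(0); ++i) {
    for (int j = 0; j < mat.dimension(1); ++j) {
      sum += mat(i, j);
    }
  }
  return sum;
}

}  // namespace tensorflow
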
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
new file mode 100644
index 0000000000..7353191c74
--- /dev/null
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -0,0 +1,28 @@
+#include "tensorflow/core/framework/tensor_util.h"
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace tensor {
+
+Tensor DeepCopy(const Tensor& other) {
+ Tensor tmp = Tensor(other.dtype(), other.shape());
+ if (DataTypeCanUseMemcpy(other.dtype())) {
+ StringPiece other_data = other.tensor_data();
+
+ // We use StringPiece as a convenient map over the tensor buffer,
+ // but we cast the type to get to the underlying buffer to do the
+ // copy.
+ StringPiece tmp_data = tmp.tensor_data();
+ memcpy(const_cast<char*>(tmp_data.data()), other_data.data(),
+ other_data.size());
+ } else {
+ CHECK_EQ(DT_STRING, other.dtype());
+ tmp.flat<string>() = other.flat<string>();
+ }
+ return tmp;
+}
+
+} // namespace tensor
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
new file mode 100644
index 0000000000..a8dde1d0ca
--- /dev/null
+++ b/tensorflow/core/framework/tensor_util.h
@@ -0,0 +1,21 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace tensor {
+
+// DeepCopy returns a tensor whose contents are a deep copy of the
+// contents of 'other'. This function is intended only for
+// convenience, not speed.
+//
+// REQUIRES: 'other' must point to data stored in CPU memory.
+// REQUIRES: 'other' must be a Tensor of a copy-able type if
+// 'other' is not appropriately memory-aligned.
+Tensor DeepCopy(const Tensor& other);
+
+} // namespace tensor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc
new file mode 100644
index 0000000000..fef7468151
--- /dev/null
+++ b/tensorflow/core/framework/tensor_util_test.cc
@@ -0,0 +1,124 @@
+#include "tensorflow/core/framework/tensor_util.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(TensorUtil, DeepCopy0d) {
+ Tensor x(DT_FLOAT, TensorShape({}));
+ x.scalar<float>()() = 10.0;
+
+ // Make y a deep copy of x and then change it.
+ Tensor y = tensor::DeepCopy(x);
+ y.scalar<float>()() = 20.0;
+
+ // x doesn't change
+ EXPECT_EQ(10.0, x.scalar<float>()());
+
+ // Change x.
+ x.scalar<float>()() = 30.0;
+
+ // Y doesn't change.
+ EXPECT_EQ(20.0, y.scalar<float>()());
+
+ Tensor z = tensor::DeepCopy(y);
+
+ // Change y.
+ y.scalar<float>()() = 40.0;
+
+ // The final states should all be different.
+ EXPECT_EQ(20.0, z.scalar<float>()());
+ EXPECT_EQ(30.0, x.scalar<float>()());
+ EXPECT_EQ(40.0, y.scalar<float>()());
+
+ // Should have the same shape and type.
+ EXPECT_EQ(TensorShape({}), x.shape());
+ EXPECT_EQ(TensorShape({}), y.shape());
+ EXPECT_EQ(TensorShape({}), z.shape());
+
+ EXPECT_EQ(DT_FLOAT, x.dtype());
+ EXPECT_EQ(DT_FLOAT, y.dtype());
+ EXPECT_EQ(DT_FLOAT, z.dtype());
+}
+
+TEST(TensorUtil, DeepCopy) {
+ Tensor x(DT_FLOAT, TensorShape({1}));
+ x.flat<float>()(0) = 10.0;
+
+ // Make y a deep copy of x and then change it.
+ Tensor y = tensor::DeepCopy(x);
+ y.flat<float>()(0) = 20.0;
+
+ // x doesn't change
+ EXPECT_EQ(10.0, x.flat<float>()(0));
+
+ // Change x.
+ x.flat<float>()(0) = 30.0;
+
+ // Y doesn't change.
+ EXPECT_EQ(20.0, y.flat<float>()(0));
+
+ Tensor z = tensor::DeepCopy(y);
+
+ // Change y.
+ y.flat<float>()(0) = 40.0;
+
+ // The final states should all be different.
+ EXPECT_EQ(20.0, z.flat<float>()(0));
+ EXPECT_EQ(30.0, x.flat<float>()(0));
+ EXPECT_EQ(40.0, y.flat<float>()(0));
+
+ // Should have the same shape and type.
+ EXPECT_EQ(TensorShape({1}), x.shape());
+ EXPECT_EQ(TensorShape({1}), y.shape());
+ EXPECT_EQ(TensorShape({1}), z.shape());
+
+ EXPECT_EQ(DT_FLOAT, x.dtype());
+ EXPECT_EQ(DT_FLOAT, y.dtype());
+ EXPECT_EQ(DT_FLOAT, z.dtype());
+
+ // Test string deep copy
+ Tensor str1(DT_STRING, TensorShape({2}));
+ str1.flat<string>()(0) = "foo1";
+ str1.flat<string>()(1) = "foo2";
+ Tensor str2 = tensor::DeepCopy(str1);
+ str2.flat<string>()(0) = "bar1";
+ str2.flat<string>()(1) = "bar2";
+ EXPECT_NE(str2.flat<string>()(0), str1.flat<string>()(0));
+}
+
+TEST(TensorUtil, DeepCopySlice) {
+ Tensor x(DT_INT32, TensorShape({10}));
+ x.flat<int32>().setConstant(1);
+
+ // Slice 'x' -- y still refers to the same buffer.
+ Tensor y = x.Slice(2, 6);
+
+ // Do a deep copy of y, which is a slice.
+ Tensor z = tensor::DeepCopy(y);
+
+ // Set x to be different.
+ x.flat<int32>().setConstant(2);
+
+ EXPECT_EQ(TensorShape({10}), x.shape());
+ EXPECT_EQ(TensorShape({4}), y.shape());
+ EXPECT_EQ(TensorShape({4}), z.shape());
+ EXPECT_EQ(DT_INT32, x.dtype());
+ EXPECT_EQ(DT_INT32, y.dtype());
+ EXPECT_EQ(DT_INT32, z.dtype());
+
+ // x and y should now all be '2', but z should be '1'.
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(2, x.flat<int32>()(i));
+ }
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(2, y.unaligned_flat<int32>()(i));
+ EXPECT_EQ(1, z.flat<int32>()(i));
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tracking_allocator.cc b/tensorflow/core/framework/tracking_allocator.cc
new file mode 100644
index 0000000000..78311ded19
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator.cc
@@ -0,0 +1,100 @@
+#include "tensorflow/core/framework/tracking_allocator.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+TrackingAllocator::TrackingAllocator(Allocator* allocator)
+ : allocator_(allocator),
+ ref_(1),
+ allocated_(0),
+ high_watermark_(0),
+ total_bytes_(0) {}
+
+void* TrackingAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+ void* ptr = allocator_->AllocateRaw(alignment, num_bytes);
+ // If memory is exhausted AllocateRaw returns nullptr, and we should
+ // pass this through to the caller
+ if (nullptr == ptr) {
+ return ptr;
+ }
+ if (allocator_->TracksAllocationSizes()) {
+ size_t allocated_bytes = allocator_->AllocatedSize(ptr);
+ {
+ mutex_lock lock(mu_);
+ allocated_ += allocated_bytes;
+ high_watermark_ = std::max(high_watermark_, allocated_);
+ total_bytes_ += allocated_bytes;
+ ++ref_;
+ }
+ } else {
+ mutex_lock lock(mu_);
+ total_bytes_ += num_bytes;
+ ++ref_;
+ }
+ return ptr;
+}
+
+void TrackingAllocator::DeallocateRaw(void* ptr) {
+ // freeing a null ptr is a no-op
+ if (nullptr == ptr) {
+ return;
+ }
+ bool should_delete;
+ // fetch the following outside the lock in case the call to
+ // AllocatedSize is slow
+ bool tracks_allocation_sizes = allocator_->TracksAllocationSizes();
+ size_t allocated_bytes = 0;
+ if (tracks_allocation_sizes) {
+ allocated_bytes = allocator_->AllocatedSize(ptr);
+ }
+ Allocator* allocator = allocator_;
+ {
+ mutex_lock lock(mu_);
+ if (tracks_allocation_sizes) {
+ CHECK_GE(allocated_, allocated_bytes);
+ allocated_ -= allocated_bytes;
+ }
+ should_delete = UnRef();
+ }
+ allocator->DeallocateRaw(ptr);
+ if (should_delete) {
+ delete this;
+ }
+}
+
+bool TrackingAllocator::TracksAllocationSizes() {
+ return allocator_->TracksAllocationSizes();
+}
+
+size_t TrackingAllocator::RequestedSize(void* ptr) {
+ return allocator_->RequestedSize(ptr);
+}
+
+size_t TrackingAllocator::AllocatedSize(void* ptr) {
+ return allocator_->AllocatedSize(ptr);
+}
+
+std::pair<size_t, size_t> TrackingAllocator::GetSizesAndUnRef() {
+ size_t high_watermark;
+ size_t total_bytes;
+ bool should_delete;
+ {
+ mutex_lock lock(mu_);
+ high_watermark = high_watermark_;
+ total_bytes = total_bytes_;
+ should_delete = UnRef();
+ }
+ if (should_delete) {
+ delete this;
+ }
+ return std::make_pair(total_bytes, high_watermark);
+}
+
+bool TrackingAllocator::UnRef() {
+ CHECK_GE(ref_, 1);
+ --ref_;
+ return (ref_ == 0);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
new file mode 100644
index 0000000000..f809e3822c
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -0,0 +1,80 @@
+#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// TrackingAllocator is a wrapper for an Allocator. It keeps a running
+// count of the number of bytes allocated through the wrapper. It is
+// used by the Executor to "charge" allocations to particular Op
+// executions. Each Op gets a separate TrackingAllocator wrapper
+// around the underlying allocator.
+//
+// The implementation assumes the invariant that all calls to
+// AllocateRaw by an Op (or work items spawned by the Op) will occur
+// before the Op's Compute method returns. Thus the high watermark is
+// established once Compute returns.
+//
+// DeallocateRaw can be called long after the Op has finished,
+// e.g. when an output tensor is deallocated, and the wrapper cannot
+// be deleted until the last of these calls has occurred. The
+// TrackingAllocator keeps track of outstanding calls using a
+// reference count, and deletes itself once the last call has been
+// received and the high watermark has been retrieved.
+class TrackingAllocator : public Allocator {
+ public:
+ explicit TrackingAllocator(Allocator* allocator);
+ string Name() override { return allocator_->Name(); }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ bool TracksAllocationSizes() override;
+ size_t RequestedSize(void* ptr) override;
+ size_t AllocatedSize(void* ptr) override;
+
+ // If the underlying allocator tracks allocation sizes, this returns
+ // a pair where the first value is the total number of bytes
+ // allocated through this wrapper, and the second value is the high
+ // watermark of bytes allocated through this wrapper. If the
+ // underlying allocator does not track allocation sizes the first
+ // value is the total number of bytes requested through this wrapper
+ // and the second is 0.
+ //
+ // After GetSizesAndUnRef is called, the only further calls allowed
+ // on this wrapper are calls to DeallocateRaw with pointers that
+ // were allocated by this wrapper and have not yet been
+ // deallocated. After this call completes and all allocated pointers
+ // have been deallocated the wrapper will delete itself.
+ std::pair<size_t, size_t> GetSizesAndUnRef();
+
+ private:
+ ~TrackingAllocator() override {}
+ bool UnRef() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ Allocator* allocator_; // not owned.
+ mutex mu_;
+ // the number of calls to AllocateRaw that have not yet been matched
+ // by a corresponding call to DeallocateRaw, plus 1 if the Executor
+ // has not yet read out the high watermark.
+ int ref_ GUARDED_BY(mu_);
+ // the current number of outstanding bytes that have been allocated
+ // by this wrapper, or 0 if the underlying allocator does not track
+ // allocation sizes.
+ size_t allocated_ GUARDED_BY(mu_);
+ // the maximum number of outstanding bytes that have been allocated
+ // by this wrapper, or 0 if the underlying allocator does not track
+ // allocation sizes.
+ size_t high_watermark_ GUARDED_BY(mu_);
+ // the total number of bytes that have been allocated by this
+ // wrapper if the underlying allocator tracks allocation sizes,
+ // otherwise the total number of bytes that have been requested by
+ // this allocator.
+ size_t total_bytes_ GUARDED_BY(mu_);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
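
A sketch of the per-Op charging pattern described in the class comment, using the process CPU allocator (illustrative only; the real wiring lives in the Executor).

#include <utility>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void ChargeOneOpExecution() {
  Allocator* base = cpu_allocator();

  // One wrapper per Op execution.
  TrackingAllocator* ta = new TrackingAllocator(base);

  void* buf = ta->AllocateRaw(32 /* alignment */, 1024 /* num_bytes */);

  // Once the Op's Compute() has returned, read the charge. The CPU
  // allocator does not track allocation sizes, so this yields
  // (bytes requested, 0).
  std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
  CHECK_EQ(1024, sizes.first);

  // Outstanding buffers are still released through the wrapper; it
  // deletes itself after the last outstanding deallocation.
  ta->DeallocateRaw(buf);
}

}  // namespace tensorflow
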
diff --git a/tensorflow/core/framework/tracking_allocator_test.cc b/tensorflow/core/framework/tracking_allocator_test.cc
new file mode 100644
index 0000000000..90ce851775
--- /dev/null
+++ b/tensorflow/core/framework/tracking_allocator_test.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/core/framework/tracking_allocator.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+class TestableSizeTrackingAllocator : public Allocator {
+ public:
+ string Name() override { return "test"; }
+ void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override {
+ void* ptr = malloc(num_bytes);
+ size_map_[ptr] = num_bytes;
+ return ptr;
+ }
+ void DeallocateRaw(void* ptr) override {
+ const auto& iter = size_map_.find(ptr);
+ EXPECT_NE(size_map_.end(), iter);
+ size_map_.erase(iter);
+ free(ptr);
+ }
+ bool TracksAllocationSizes() override { return true; }
+ size_t RequestedSize(void* ptr) override {
+ const auto& iter = size_map_.find(ptr);
+ EXPECT_NE(size_map_.end(), iter);
+ return iter->second;
+ }
+
+ private:
+ std::unordered_map<void*, size_t> size_map_;
+};
+
+class NoMemoryAllocator : public Allocator {
+ public:
+ string Name() override { return "test"; }
+ void* AllocateRaw(size_t /*alignment*/, size_t num_bytes) override {
+ return nullptr;
+ }
+ void DeallocateRaw(void* ptr) override {}
+ bool TracksAllocationSizes() override { return true; }
+};
+
+TEST(TrackingAllocatorTest, SimpleNoTracking) {
+ Allocator* a = cpu_allocator();
+
+ EXPECT_FALSE(a->TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(a);
+
+ void* p1 = ta->AllocateRaw(4, 4);
+ ta->Deallocate(p1);
+ void* p2 = ta->AllocateRaw(4, 12);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(16, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+
+ ta->Deallocate(p2);
+}
+
+TEST(TrackingAllocatorTest, SimpleTracking) {
+ TestableSizeTrackingAllocator a = TestableSizeTrackingAllocator();
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ void* p1 = ta->AllocateRaw(4, 12);
+ ta->Deallocate(p1);
+ void* p2 = ta->AllocateRaw(4, 4);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(16, sizes.first);
+ EXPECT_EQ(12, sizes.second);
+
+ ta->Deallocate(p2);
+}
+
+TEST(TrackingAllocatorTest, OutOfMemory) {
+ NoMemoryAllocator a;
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ void* p1 = ta->AllocateRaw(4, 12);
+ EXPECT_EQ(nullptr, p1);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(0, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+}
+
+TEST(TrackingAllocatorTest, FreeNullPtr) {
+ NoMemoryAllocator a;
+
+ EXPECT_TRUE(a.TracksAllocationSizes());
+
+ TrackingAllocator* ta = new TrackingAllocator(&a);
+
+ ta->DeallocateRaw(nullptr);
+
+ std::pair<size_t, size_t> sizes = ta->GetSizesAndUnRef();
+
+ EXPECT_EQ(0, sizes.first);
+ EXPECT_EQ(0, sizes.second);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h
new file mode 100644
index 0000000000..d87b6ff49b
--- /dev/null
+++ b/tensorflow/core/framework/type_traits.h
@@ -0,0 +1,69 @@
+#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Traits describing the quantization attribute of types.
+struct true_type {
+ static const bool value = true;
+};
+struct false_type {
+ static const bool value = false;
+};
+
+// Default is_quantized is false.
+template <typename T>
+struct is_quantized : false_type {};
+
+// Specialize the quantized types.
+template <>
+struct is_quantized<qint8> : true_type {};
+template <>
+struct is_quantized<quint8> : true_type {};
+template <>
+struct is_quantized<qint32> : true_type {};
+
+// All types not specialized are marked invalid.
+template <class T>
+struct IsValidDataType {
+ static constexpr bool value = false;
+};
+
+// Extra validity checking; not part of public API.
+struct TestIsValidDataType {
+ static_assert(IsValidDataType<int64>::value, "Incorrect impl for int64");
+ static_assert(IsValidDataType<int32>::value, "Incorrect impl for int32");
+};
+
+} // namespace tensorflow
+
+// Define numeric limits for our quantized types as subclasses of the
+// standard types.
+namespace std {
+template <>
+class numeric_limits<tensorflow::qint8>
+ : public numeric_limits<tensorflow::int8> {};
+template <>
+class numeric_limits<tensorflow::quint8>
+ : public numeric_limits<tensorflow::uint8> {};
+template <>
+class numeric_limits<tensorflow::qint32>
+ : public numeric_limits<tensorflow::int32> {};
+
+// Specialize is_signed for quantized types.
+template <>
+struct is_signed<tensorflow::qint8> : public is_signed<tensorflow::int8> {};
+template <>
+struct is_signed<tensorflow::quint8> : public is_signed<tensorflow::uint8> {};
+template <>
+struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {};
+
+} // namespace std
+
+#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
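
A few compile-time checks showing how the traits above behave (illustrative only):

#include "tensorflow/core/framework/type_traits.h"

namespace tensorflow {

// is_quantized<T> distinguishes the quantized integer types from the
// plain ones.
static_assert(is_quantized<qint8>::value, "qint8 is quantized");
static_assert(is_quantized<quint8>::value, "quint8 is quantized");
static_assert(!is_quantized<int8>::value, "plain int8 is not quantized");

// The std::numeric_limits specializations forward to the underlying
// integer types, e.g. qint8 shares the limits of int8.
static_assert(std::numeric_limits<qint8>::max() ==
                  std::numeric_limits<int8>::max(),
              "qint8 shares int8 limits");

}  // namespace tensorflow
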
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
new file mode 100644
index 0000000000..01b9fca3b6
--- /dev/null
+++ b/tensorflow/core/framework/types.cc
@@ -0,0 +1,210 @@
+#include "tensorflow/core/framework/types.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+bool DeviceType::operator<(const DeviceType& other) const {
+ return type_ < other.type_;
+}
+
+bool DeviceType::operator==(const DeviceType& other) const {
+ return type_ == other.type_;
+}
+
+std::ostream& operator<<(std::ostream& os, const DeviceType& d) {
+ os << d.type();
+ return os;
+}
+
+const char* const DEVICE_CPU = "CPU";
+const char* const DEVICE_GPU = "GPU";
+
+string DataTypeString(DataType dtype) {
+ if (IsRefType(dtype)) {
+ DataType non_ref = static_cast<DataType>(dtype - kDataTypeRefOffset);
+ return strings::StrCat(DataTypeString(non_ref), "_ref");
+ }
+ switch (dtype) {
+ case DT_INVALID:
+ return "INVALID";
+ case DT_FLOAT:
+ return "float";
+ case DT_DOUBLE:
+ return "double";
+ case DT_INT32:
+ return "int32";
+ case DT_UINT8:
+ return "uint8";
+ case DT_INT16:
+ return "int16";
+ case DT_INT8:
+ return "int8";
+ case DT_STRING:
+ return "string";
+ case DT_COMPLEX64:
+ return "complex64";
+ case DT_INT64:
+ return "int64";
+ case DT_BOOL:
+ return "bool";
+ case DT_QINT8:
+ return "qint8";
+ case DT_QUINT8:
+ return "quint8";
+ case DT_QINT32:
+ return "qint32";
+ case DT_BFLOAT16:
+ return "bfloat16";
+ default:
+ LOG(FATAL) << "Unrecognized DataType enum value " << dtype;
+ return "";
+ }
+}
+
+bool DataTypeFromString(StringPiece sp, DataType* dt) {
+ if (sp.ends_with("_ref")) {
+ sp.remove_suffix(4);
+ DataType non_ref;
+ if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) {
+ *dt = static_cast<DataType>(non_ref + kDataTypeRefOffset);
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ if (sp == "float" || sp == "float32") {
+ *dt = DT_FLOAT;
+ return true;
+ } else if (sp == "double" || sp == "float64") {
+ *dt = DT_DOUBLE;
+ return true;
+ } else if (sp == "int32") {
+ *dt = DT_INT32;
+ return true;
+ } else if (sp == "uint8") {
+ *dt = DT_UINT8;
+ return true;
+ } else if (sp == "int16") {
+ *dt = DT_INT16;
+ return true;
+ } else if (sp == "int8") {
+ *dt = DT_INT8;
+ return true;
+ } else if (sp == "string") {
+ *dt = DT_STRING;
+ return true;
+ } else if (sp == "complex64") {
+ *dt = DT_COMPLEX64;
+ return true;
+ } else if (sp == "int64") {
+ *dt = DT_INT64;
+ return true;
+ } else if (sp == "bool") {
+ *dt = DT_BOOL;
+ return true;
+ } else if (sp == "qint8") {
+ *dt = DT_QINT8;
+ return true;
+ } else if (sp == "quint8") {
+ *dt = DT_QUINT8;
+ return true;
+ } else if (sp == "qint32") {
+ *dt = DT_QINT32;
+ return true;
+ } else if (sp == "bfloat16") {
+ *dt = DT_BFLOAT16;
+ return true;
+ }
+ return false;
+}
+
+string DeviceTypeString(DeviceType device_type) { return device_type.type(); }
+
+string DataTypeSliceString(const DataTypeSlice types) {
+ string out;
+ for (auto it = types.begin(); it != types.end(); ++it) {
+ strings::StrAppend(&out, ((it == types.begin()) ? "" : ", "),
+ DataTypeString(*it));
+ }
+ return out;
+}
+
+DataTypeVector AllTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
+ DT_INT8, DT_STRING, DT_COMPLEX64, DT_INT64, DT_BOOL,
+ DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+#ifndef __ANDROID__
+
+DataTypeVector RealNumberTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8};
+}
+
+DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; }
+
+DataTypeVector RealAndQuantizedTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
+ DT_INT16, DT_INT8, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+DataTypeVector NumberTypes() {
+ return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16,
+ DT_INT8, DT_COMPLEX64, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+#else // __ANDROID__
+
+DataTypeVector RealNumberTypes() { return {DT_FLOAT, DT_INT32}; }
+
+DataTypeVector NumberTypes() {
+ return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+DataTypeVector QuantizedTypes() { return {DT_QINT8, DT_QUINT8, DT_QINT32}; }
+
+DataTypeVector RealAndQuantizedTypes() {
+ return {DT_FLOAT, DT_INT32, DT_QINT8, DT_QUINT8, DT_QINT32};
+}
+
+#endif // __ANDROID__
+
+// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying
+// is_simple<T> in tensor.cc (and possibly choose a more general name?)
+bool DataTypeCanUseMemcpy(DataType dt) {
+ switch (dt) {
+ case DT_FLOAT:
+ case DT_DOUBLE:
+ case DT_INT32:
+ case DT_UINT8:
+ case DT_INT16:
+ case DT_INT8:
+ case DT_COMPLEX64:
+ case DT_INT64:
+ case DT_BOOL:
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_QINT32:
+ case DT_BFLOAT16:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool DataTypeIsQuantized(DataType dt) {
+ switch (dt) {
+ case DT_QINT8:
+ case DT_QUINT8:
+ case DT_QINT32:
+ return true;
+ default:
+ return false;
+ }
+}
+
+} // namespace tensorflow
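
A short round-trip sketch of the string helpers and the memcpy predicate defined above (illustrative only):

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void DataTypeStringRoundTrip() {
  DataType dt;
  CHECK(DataTypeFromString("float", &dt));  // "float32" is accepted too
  CHECK_EQ(DT_FLOAT, dt);
  CHECK_EQ("float", DataTypeString(dt));
  CHECK_EQ("float_ref", DataTypeString(MakeRefType(dt)));

  // POD-like types can be copied with memcpy; DT_STRING cannot.
  CHECK(DataTypeCanUseMemcpy(DT_FLOAT));
  CHECK(!DataTypeCanUseMemcpy(DT_STRING));
}

}  // namespace tensorflow
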
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
new file mode 100644
index 0000000000..2d417cf076
--- /dev/null
+++ b/tensorflow/core/framework/types.h
@@ -0,0 +1,168 @@
+#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_
+#define TENSORFLOW_FRAMEWORK_TYPES_H_
+
+#include <map>
+#include <set>
+#include <string>
+
+#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
+
+namespace tensorflow {
+
+// MemoryType is used to describe whether input or output Tensors of
+// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
+// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
+// devices).
+enum MemoryType {
+ DEVICE_MEMORY = 0,
+ HOST_MEMORY = 1,
+};
+
+// A DeviceType is just a string, but we wrap it up in a class to give
+// some type checking as we're passing these around.
+class DeviceType {
+ public:
+ DeviceType(const char* type) // NOLINT(runtime/explicit)
+ : type_(type) {}
+
+ explicit DeviceType(StringPiece type) : type_(type.data(), type.size()) {}
+
+ const char* type() const { return type_.c_str(); }
+
+ bool operator<(const DeviceType& other) const;
+ bool operator==(const DeviceType& other) const;
+ bool operator!=(const DeviceType& other) const { return !(*this == other); }
+
+ private:
+ string type_;
+};
+std::ostream& operator<<(std::ostream& os, const DeviceType& d);
+
+// Convenient constants that can be passed to a DeviceType constructor
+extern const char* const DEVICE_CPU; // "CPU"
+extern const char* const DEVICE_GPU; // "GPU"
+
+typedef gtl::InlinedVector<MemoryType, 4> MemoryTypeVector;
+
+typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
+typedef gtl::ArraySlice<DataType> DataTypeSlice;
+
+typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
+
+// Convert the enums to strings for errors:
+string DataTypeString(DataType dtype);
+string DeviceTypeString(DeviceType device_type);
+string DataTypeSliceString(const DataTypeSlice dtypes);
+inline string DataTypeVectorString(const DataTypeVector& dtypes) {
+ return DataTypeSliceString(dtypes);
+}
+
+// If "sp" names a valid type, store it in "*dt" and return true. Otherwise,
+// return false.
+bool DataTypeFromString(StringPiece sp, DataType* dt);
+
+// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc.
+enum { kDataTypeRefOffset = 100 };
+inline bool IsRefType(DataType dtype) {
+ return dtype > static_cast<DataType>(kDataTypeRefOffset);
+}
+inline DataType MakeRefType(DataType dtype) {
+ DCHECK(!IsRefType(dtype));
+ return static_cast<DataType>(dtype + kDataTypeRefOffset);
+}
+inline DataType RemoveRefType(DataType dtype) {
+ DCHECK(IsRefType(dtype));
+ return static_cast<DataType>(dtype - kDataTypeRefOffset);
+}
+inline DataType BaseType(DataType dtype) {
+ return IsRefType(dtype) ? RemoveRefType(dtype) : dtype;
+}
+
+// Returns true if the actual type is the same as or ref of the expected type.
+inline bool TypesCompatible(DataType expected, DataType actual) {
+ return expected == actual || expected == BaseType(actual);
+}
+
+// Does not include _ref types.
+DataTypeVector AllTypes();
+
+// Return the list of all numeric types.
+// NOTE: On Android, we only include the float and int32 types for now.
+DataTypeVector RealNumberTypes(); // Types that support '<' and '>'.
+DataTypeVector NumberTypes(); // Includes complex and quantized types.
+
+DataTypeVector QuantizedTypes();
+DataTypeVector RealAndQuantizedTypes(); // Types that support '<' and
+ // '>', including quantized
+ // types
+
+// Validates type T for whether it is a supported DataType.
+template <class T>
+struct IsValidDataType;
+
+// DataTypeToEnum<T>::v() and DataTypeToEnum<T>::value are the DataType
+// constants for T, e.g. DataTypeToEnum<float>::v() is DT_FLOAT.
+template <class T>
+struct DataTypeToEnum {
+ static_assert(IsValidDataType<T>::value, "Specified Data Type not supported");
+}; // Specializations below
+
+// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
+// EnumToDataType<DT_FLOAT>::Type is float.
+template <DataType VALUE>
+struct EnumToDataType {}; // Specializations below
+
+// Template specialization for both DataTypeToEnum and EnumToDataType.
+#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \
+ template <> \
+ struct DataTypeToEnum<TYPE> { \
+ static DataType v() { return ENUM; } \
+ static DataType ref() { return MakeRefType(ENUM); } \
+ static constexpr DataType value = ENUM; \
+ }; \
+ template <> \
+ struct IsValidDataType<TYPE> { \
+ static constexpr bool value = true; \
+ }; \
+ template <> \
+ struct EnumToDataType<ENUM> { \
+ typedef TYPE Type; \
+ }
+
+// We use Eigen's QInt implementations for our quantized int types.
+typedef Eigen::QInt8 qint8;
+typedef Eigen::QUInt8 quint8;
+typedef Eigen::QInt32 qint32;
+
+MATCH_TYPE_AND_ENUM(float, DT_FLOAT);
+MATCH_TYPE_AND_ENUM(double, DT_DOUBLE);
+MATCH_TYPE_AND_ENUM(int32, DT_INT32);
+MATCH_TYPE_AND_ENUM(uint8, DT_UINT8);
+MATCH_TYPE_AND_ENUM(int16, DT_INT16);
+MATCH_TYPE_AND_ENUM(int8, DT_INT8);
+MATCH_TYPE_AND_ENUM(string, DT_STRING);
+MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64);
+MATCH_TYPE_AND_ENUM(int64, DT_INT64);
+MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
+MATCH_TYPE_AND_ENUM(qint8, DT_QINT8);
+MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8);
+MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
+MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
+
+#undef MATCH_TYPE_AND_ENUM
+
+bool DataTypeCanUseMemcpy(DataType dt);
+
+bool DataTypeIsQuantized(DataType dt);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TYPES_H_
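
A sketch of the compile-time type/enum mapping and the _ref helpers declared above (illustrative only):

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// DataTypeToEnum and EnumToDataType are inverse mappings set up by
// MATCH_TYPE_AND_ENUM.
static_assert(DataTypeToEnum<float>::value == DT_FLOAT, "float -> DT_FLOAT");
static_assert(DataTypeToEnum<int32>::value == DT_INT32, "int32 -> DT_INT32");
typedef EnumToDataType<DT_INT64>::Type Int64T;
static_assert(sizeof(Int64T) == 8, "DT_INT64 maps to an 8-byte integer");

void RefTypeSketch() {
  // The _REF enums sit kDataTypeRefOffset (100) above their base types.
  DataType ref = MakeRefType(DT_FLOAT);  // DT_FLOAT_REF
  CHECK(IsRefType(ref));
  CHECK_EQ(DT_FLOAT, BaseType(ref));
  CHECK(TypesCompatible(DT_FLOAT, ref));  // a ref of the expected type is OK
}

}  // namespace tensorflow
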
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
new file mode 100644
index 0000000000..e5dc9c45a0
--- /dev/null
+++ b/tensorflow/core/framework/types.proto
@@ -0,0 +1,48 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+enum DataType {
+ // Not a legal value for DataType. Used to indicate a DataType field
+ // has not been set.
+ DT_INVALID = 0;
+
+ // Data types that all computation devices are expected to be
+ // capable of supporting.
+ DT_FLOAT = 1;
+ DT_DOUBLE = 2;
+ DT_INT32 = 3;
+ DT_UINT8 = 4;
+ DT_INT16 = 5;
+ DT_INT8 = 6;
+ DT_STRING = 7;
+ DT_COMPLEX64 = 8; // Single-precision complex
+ DT_INT64 = 9;
+ DT_BOOL = 10;
+ DT_QINT8 = 11; // Quantized int8
+ DT_QUINT8 = 12; // Quantized uint8
+ DT_QINT32 = 13; // Quantized int32
+ DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
+
+ // TODO(josh11b): DT_GENERIC_PROTO = ??;
+ // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? DT_UINT16?
+ // TODO(zhifengc): DT_COMPLEX128 (double-precision complex)?
+
+ // Do not use! These are only for parameters. Every enum above
+ // should have a corresponding value below (verified by types_test).
+ DT_FLOAT_REF = 101;
+ DT_DOUBLE_REF = 102;
+ DT_INT32_REF = 103;
+ DT_UINT8_REF = 104;
+ DT_INT16_REF = 105;
+ DT_INT8_REF = 106;
+ DT_STRING_REF = 107;
+ DT_COMPLEX64_REF = 108;
+ DT_INT64_REF = 109;
+ DT_BOOL_REF = 110;
+ DT_QINT8_REF = 111;
+ DT_QUINT8_REF = 112;
+ DT_QINT32_REF = 113;
+ DT_BFLOAT16_REF = 114;
+}
diff --git a/tensorflow/core/framework/types_test.cc b/tensorflow/core/framework/types_test.cc
new file mode 100644
index 0000000000..eb92600397
--- /dev/null
+++ b/tensorflow/core/framework/types_test.cc
@@ -0,0 +1,117 @@
+#include "tensorflow/core/framework/types.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(TypesTest, DeviceTypeName) {
+ EXPECT_EQ("CPU", DeviceTypeString(DeviceType(DEVICE_CPU)));
+ EXPECT_EQ("GPU", DeviceTypeString(DeviceType(DEVICE_GPU)));
+}
+
+TEST(TypesTest, kDataTypeRefOffset) {
+ // Basic sanity check
+ EXPECT_EQ(DT_FLOAT + kDataTypeRefOffset, DT_FLOAT_REF);
+
+ // Use the meta-data provided by proto2 to iterate through the basic
+ // types and validate that adding kDataTypeRefOffset gives the
+ // corresponding reference type.
+ const auto* enum_descriptor = DataType_descriptor();
+ int e = DataType_MIN;
+ if (e == DT_INVALID) ++e;
+ int e_ref = e + kDataTypeRefOffset;
+ EXPECT_FALSE(DataType_IsValid(e_ref - 1))
+ << "Reference enum "
+ << enum_descriptor->FindValueByNumber(e_ref - 1)->name()
+ << " without corresponding base enum with value " << e - 1;
+ for (;
+ DataType_IsValid(e) && DataType_IsValid(e_ref) && e_ref <= DataType_MAX;
+ ++e, ++e_ref) {
+ string enum_name = enum_descriptor->FindValueByNumber(e)->name();
+ string enum_ref_name = enum_descriptor->FindValueByNumber(e_ref)->name();
+ EXPECT_EQ(enum_name + "_REF", enum_ref_name)
+ << enum_name << "_REF should have value " << e_ref << " not "
+ << enum_ref_name;
+ // Validate DataTypeString() as well.
+ DataType dt_e = static_cast<DataType>(e);
+ DataType dt_e_ref = static_cast<DataType>(e_ref);
+ EXPECT_EQ(DataTypeString(dt_e) + "_ref", DataTypeString(dt_e_ref));
+
+ // Test DataTypeFromString reverse conversion
+ DataType dt_e2, dt_e2_ref;
+ EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e), &dt_e2));
+ EXPECT_EQ(dt_e, dt_e2);
+ EXPECT_TRUE(DataTypeFromString(DataTypeString(dt_e_ref), &dt_e2_ref));
+ EXPECT_EQ(dt_e_ref, dt_e2_ref);
+ }
+ ASSERT_FALSE(DataType_IsValid(e))
+ << "Should define " << enum_descriptor->FindValueByNumber(e)->name()
+ << "_REF to be " << e_ref;
+ ASSERT_FALSE(DataType_IsValid(e_ref))
+ << "Extra reference enum "
+ << enum_descriptor->FindValueByNumber(e_ref)->name()
+ << " without corresponding base enum with value " << e;
+ ASSERT_LT(DataType_MAX, e_ref) << "Gap in reference types, missing value for "
+ << e_ref;
+
+ // Make sure there are no enums defined after the last regular type before
+ // the first reference type.
+ for (; e < DataType_MIN + kDataTypeRefOffset; ++e) {
+ EXPECT_FALSE(DataType_IsValid(e))
+ << "Discontinuous enum value "
+ << enum_descriptor->FindValueByNumber(e)->name() << " = " << e;
+ }
+}
+
+TEST(TypesTest, DataTypeFromString) {
+ DataType dt;
+ ASSERT_TRUE(DataTypeFromString("int32", &dt));
+ EXPECT_EQ(DT_INT32, dt);
+ ASSERT_TRUE(DataTypeFromString("int32_ref", &dt));
+ EXPECT_EQ(DT_INT32_REF, dt);
+ EXPECT_FALSE(DataTypeFromString("int32_ref_ref", &dt));
+ EXPECT_FALSE(DataTypeFromString("foo", &dt));
+ EXPECT_FALSE(DataTypeFromString("foo_ref", &dt));
+ ASSERT_TRUE(DataTypeFromString("int64", &dt));
+ EXPECT_EQ(DT_INT64, dt);
+ ASSERT_TRUE(DataTypeFromString("int64_ref", &dt));
+ EXPECT_EQ(DT_INT64_REF, dt);
+ ASSERT_TRUE(DataTypeFromString("quint8_ref", &dt));
+ EXPECT_EQ(DT_QUINT8_REF, dt);
+ ASSERT_TRUE(DataTypeFromString("bfloat16", &dt));
+ EXPECT_EQ(DT_BFLOAT16, dt);
+}
+
+template <typename T>
+static bool GetQuantized() {
+ return is_quantized<T>::value;
+}
+
+TEST(TypesTest, QuantizedTypes) {
+ // NOTE: GUnit cannot parse is_quantized<TYPE>::value within the
+ // EXPECT_TRUE() clause, so we delegate through a template function.
+ EXPECT_TRUE(GetQuantized<qint8>());
+ EXPECT_TRUE(GetQuantized<quint8>());
+ EXPECT_TRUE(GetQuantized<qint32>());
+
+ EXPECT_FALSE(GetQuantized<int8>());
+ EXPECT_FALSE(GetQuantized<uint8>());
+ EXPECT_FALSE(GetQuantized<int16>());
+ EXPECT_FALSE(GetQuantized<int32>());
+
+ EXPECT_TRUE(DataTypeIsQuantized(DT_QINT8));
+ EXPECT_TRUE(DataTypeIsQuantized(DT_QUINT8));
+ EXPECT_TRUE(DataTypeIsQuantized(DT_QINT32));
+
+ EXPECT_FALSE(DataTypeIsQuantized(DT_INT8));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_UINT8));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_INT16));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_INT32));
+ EXPECT_FALSE(DataTypeIsQuantized(DT_BFLOAT16));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
new file mode 100644
index 0000000000..fd79ead0b1
--- /dev/null
+++ b/tensorflow/core/graph/algorithm.cc
@@ -0,0 +1,107 @@
+#include "tensorflow/core/graph/algorithm.h"
+
+#include <algorithm>
+#include <deque>
+#include <vector>
+
+namespace tensorflow {
+
+void DFS(const Graph& g, std::function<void(Node*)> enter,
+ std::function<void(Node*)> leave) {
+ // Stack of work to do.
+ struct Work {
+ Node* node;
+ bool leave; // Are we entering or leaving n?
+ };
+ std::vector<Work> stack;
+ stack.push_back(Work{g.source_node(), false});
+
+ std::vector<bool> visited(g.num_node_ids(), false);
+ while (!stack.empty()) {
+ Work w = stack.back();
+ stack.pop_back();
+
+ Node* n = w.node;
+ if (w.leave) {
+ leave(n);
+ continue;
+ }
+
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+ if (enter) enter(n);
+
+ // Arrange to call leave(n) when all done with descendants.
+ if (leave) stack.push_back(Work{n, true});
+
+ // Arrange to work on descendants.
+ for (Node* out : n->out_nodes()) {
+ if (!visited[out->id()]) {
+ // Note: we must not mark as visited until we actually process it.
+ stack.push_back(Work{out, false});
+ }
+ }
+ }
+}
+
+void GetPostOrder(const Graph& g, std::vector<Node*>* order) {
+ order->clear();
+ DFS(g, nullptr, [order](Node* n) { order->push_back(n); });
+}
+
+void GetReversePostOrder(const Graph& g, std::vector<Node*>* order) {
+ GetPostOrder(g, order);
+ std::reverse(order->begin(), order->end());
+}
+
+void PruneForReverseReachability(Graph* g,
+ const std::unordered_set<const Node*>& nodes) {
+ std::unordered_set<const Node*> visited;
+
+ // Compute set of nodes that we need to traverse in order to reach
+ // the nodes in "nodes" by performing a breadth-first search from those
+ // nodes, and accumulating the visited nodes.
+ std::deque<const Node*> queue;
+ for (const Node* n : nodes) {
+ queue.push_back(n);
+ }
+ while (!queue.empty()) {
+ const Node* n = queue.front();
+ queue.pop_front();
+ if (visited.insert(n).second) {
+ for (const Node* in : n->in_nodes()) {
+ queue.push_back(in);
+ }
+ }
+ }
+
+ // Make a pass over the graph to remove nodes not in "visited"
+ std::vector<Node*> all_nodes;
+ for (Node* n : g->nodes()) {
+ all_nodes.push_back(n);
+ }
+
+ for (Node* n : all_nodes) {
+ if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
+ g->RemoveNode(n);
+ }
+ }
+
+ // Reconnect nodes with no outgoing edges to the sink node
+ FixupSourceAndSinkEdges(g);
+}
+
+void FixupSourceAndSinkEdges(Graph* g) {
+ // Connect all nodes with no incoming edges to source.
+ // Connect all nodes with no outgoing edges to sink.
+ for (Node* n : g->nodes()) {
+ if (!n->IsSource() && n->in_edges().empty()) {
+ g->AddControlEdge(g->source_node(), n);
+ }
+ if (!n->IsSink() && n->out_edges().empty()) {
+ g->AddControlEdge(n, g->sink_node());
+ }
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
new file mode 100644
index 0000000000..58b74a0ace
--- /dev/null
+++ b/tensorflow/core/graph/algorithm.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_GRAPH_ALGORITHM_H_
+#define TENSORFLOW_GRAPH_ALGORITHM_H_
+
+#include <functional>
+#include <unordered_set>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Perform a depth-first-search on g starting at the source node.
+// If enter is not empty, calls enter(n) before visiting any children of n.
+// If leave is not empty, calls leave(n) after visiting all children of n.
+extern void DFS(const Graph& g, std::function<void(Node*)> enter,
+ std::function<void(Node*)> leave);
+
+// Stores in *order the post-order numbering of all nodes
+// in graph found via a depth first search starting at the source node.
+//
+// Note that this is equivalent to topological sorting when the
+// graph does not have cycles.
+//
+// REQUIRES: order is not NULL.
+void GetPostOrder(const Graph& g, std::vector<Node*>* order);
+
+// Stores in *order the reverse post-order numbering of all nodes
+void GetReversePostOrder(const Graph& g, std::vector<Node*>* order);
+
+// Prune nodes in "g" that are not in some path from the source node
+// to any node in 'nodes'.
+void PruneForReverseReachability(Graph* g,
+ const std::unordered_set<const Node*>& nodes);
+
+// Connect all nodes with no incoming edges to source.
+// Connect all nodes with no outgoing edges to sink.
+void FixupSourceAndSinkEdges(Graph* g);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_ALGORITHM_H_
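
A small sketch of the traversal helpers declared above (illustrative; the test below builds a concrete graph with GraphDefBuilder):

#include <vector>

#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void WalkGraphSketch(const Graph& g) {
  // Log each node on the way down, collect a post-order on the way up.
  std::vector<Node*> post_order;
  DFS(g, [](Node* n) { VLOG(1) << "enter " << n->name(); },
      [&post_order](Node* n) { post_order.push_back(n); });

  // GetPostOrder produces the same ordering as the 'leave' callback above;
  // reverse post-order is a topological order when the graph is acyclic.
  std::vector<Node*> order;
  GetReversePostOrder(g, &order);
}

}  // namespace tensorflow
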
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc
new file mode 100644
index 0000000000..48f2e1ebd7
--- /dev/null
+++ b/tensorflow/core/graph/algorithm_test.cc
@@ -0,0 +1,103 @@
+#include "tensorflow/core/graph/algorithm.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+// TODO(josh11b): Test setting the "device" field of a NodeDef.
+// TODO(josh11b): Test that feeding won't prune targets.
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("TestParams").Output("o: float");
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+
+// Checks that the order of nodes in 'inputs' respects the
+// pair orders described in 'ordered_pairs'.
+bool ExpectBefore(const std::vector<std::pair<string, string>>& ordered_pairs,
+ const std::vector<Node*>& inputs, string* error) {
+ for (const std::pair<string, string>& pair : ordered_pairs) {
+ const string& before_node = pair.first;
+ const string& after_node = pair.second;
+ bool seen_before = false;
+ bool seen_both = false;
+ for (const Node* node : inputs) {
+ if (!seen_before && after_node == node->name()) {
+ *error = strings::StrCat("Saw ", after_node, " before ", before_node);
+ return false;
+ }
+
+ if (before_node == node->name()) {
+ seen_before = true;
+ } else if (after_node == node->name()) {
+ seen_both = seen_before;
+ break;
+ }
+ }
+ if (!seen_both) {
+ *error = strings::StrCat("didn't see either ", before_node, " or ",
+ after_node);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+TEST(AlgorithmTest, ReversePostOrder) {
+ RequireDefaultOps();
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* w1 = SourceOp("TestParams", b.opts().WithName("W1"));
+ Node* w2 = SourceOp("TestParams", b.opts().WithName("W2"));
+ Node* input =
+ SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1));
+ Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1"));
+ BinaryOp("TestMul", w1, {input, 1},
+ b.opts().WithName("t2").WithControlInput(t1));
+ BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3"));
+
+ Graph g(OpRegistry::Global());
+ ASSERT_OK(b.ToGraph(&g));
+ std::vector<Node*> order;
+
+ // Test reverse post order:
+ GetReversePostOrder(g, &order);
+
+ // Check that the order respects the dependencies correctly.
+ std::vector<std::pair<string, string>> reverse_orders = {
+ {"W1", "input"}, {"W1", "t1"}, {"W1", "t2"}, {"W1", "t3"},
+ {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}};
+ string error;
+ EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error;
+
+ // A false ordering should fail the check.
+ reverse_orders = {{"input", "W1"}};
+ EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error));
+
+ // Test post order:
+ GetPostOrder(g, &order);
+
+ // Check that the order respects the dependencies correctly.
+ std::vector<std::pair<string, string>> orders = {
+ {"input", "W1"}, {"t1", "W1"}, {"t2", "W1"}, {"t3", "W1"},
+ {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}};
+ EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error;
+
+ // A false ordering should fail the check.
+ orders = {{"W1", "t3"}};
+ EXPECT_FALSE(ExpectBefore(orders, order, &error));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/colors.cc b/tensorflow/core/graph/colors.cc
new file mode 100644
index 0000000000..0eb2fc3740
--- /dev/null
+++ b/tensorflow/core/graph/colors.cc
@@ -0,0 +1,25 @@
+#include "tensorflow/core/graph/colors.h"
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Color palette
+// http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/
+static const char* kColors[] = {
+ "#F15854", // red
+ "#5DA5DA", // blue
+ "#FAA43A", // orange
+ "#60BD68", // green
+ "#F17CB0", // pink
+ "#B2912F", // brown
+ "#B276B2", // purple
+ "#DECF3F", // yellow
+ "#4D4D4D", // gray
+};
+
+const char* ColorFor(int dindex) {
+ return kColors[dindex % TF_ARRAYSIZE(kColors)];
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/colors.h b/tensorflow/core/graph/colors.h
new file mode 100644
index 0000000000..150c8dc025
--- /dev/null
+++ b/tensorflow/core/graph/colors.h
@@ -0,0 +1,14 @@
+#ifndef TENSORFLOW_GRAPH_COLORS_H_
+#define TENSORFLOW_GRAPH_COLORS_H_
+
+namespace tensorflow {
+
+// Return a color drawn from a palette to represent an entity
+// identified by "i". The return value has the form "#RRGGBB" Note
+// that the palette has a limited set of colors and therefore colors
+// will be reused eventually.
+const char* ColorFor(int dindex);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_COLORS_H_
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
new file mode 100644
index 0000000000..89bc41acfd
--- /dev/null
+++ b/tensorflow/core/graph/costmodel.cc
@@ -0,0 +1,308 @@
+#include "tensorflow/core/graph/costmodel.h"
+
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+const Microseconds kDefaultTimeEstimate(1);
+const Microseconds kMinTimeEstimate(1);
+} // namespace
+
+void CostModel::SuppressInfrequent() {
+ // Find the median of the non-zero counts, and use half of its value
+ // as the cutoff for a "normal" execution mode node.
+ if (count_.empty()) return;
+ std::vector<int32> non_zero;
+ for (auto v : count_) {
+ if (v > 0) non_zero.push_back(v);
+ }
+ const size_t sz = non_zero.size();
+ if (sz > 0) {
+ std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2,
+ non_zero.end());
+ int32 median_value = non_zero[sz / 2];
+ min_count_ = median_value / 2;
+ VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value "
+ << median_value;
+ } else {
+ min_count_ = 1;
+ }
+}
+
+void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
+ CHECK(is_global_);
+ CHECK(!cm.is_global());
+ for (const Node* n : g.nodes()) {
+ const int local_id = cm.Id(n);
+ const int global_id = Id(n);
+ if (local_id < 0 || global_id < 0) continue;
+ Ensure(global_id);
+ count_[global_id] += cm.count_[local_id];
+ time_[global_id] += cm.time_[local_id];
+ int num_slots = cm.slot_bytes_[local_id].size();
+ if (num_slots > 0) {
+ if (slot_bytes_[global_id].size() == 0) {
+ slot_bytes_[global_id].resize(num_slots);
+ } else {
+ CHECK_EQ(num_slots, slot_bytes_[global_id].size());
+ }
+ for (int s = 0; s < num_slots; ++s) {
+ slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s];
+ }
+ }
+ }
+}
+
+void CostModel::MergeFromGlobal(const CostModel& cm) {
+ CHECK(is_global_);
+ CHECK_EQ(true, cm.is_global());
+ const int num_nodes = cm.count_.size();
+ Ensure(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ count_[i] += cm.count_[i];
+ time_[i] += cm.time_[i];
+ int num_slots = cm.slot_bytes_[i].size();
+ if (num_slots > 0) {
+ if (slot_bytes_[i].size() == 0) {
+ slot_bytes_[i].resize(num_slots);
+ } else {
+ CHECK_EQ(num_slots, slot_bytes_[i].size());
+ }
+ for (int s = 0; s < num_slots; ++s) {
+ slot_bytes_[i][s] += cm.slot_bytes_[i][s];
+ }
+ }
+ }
+}
+
+void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
+ const StepStats& ss) {
+ CHECK(is_global_);
+ for (auto& ds : ss.dev_stats()) {
+ for (auto& ns : ds.node_stats()) {
+ NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name());
+ // We don't keep stats for nodes not in the global graph, i.e.
+ // copy/send/recv nodes, feed/fetch, etc.
+ if (iter == map.end()) continue;
+ int32 global_id = iter->second;
+ Ensure(global_id);
+ int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros();
+ count_[global_id]++;
+ time_[global_id] += elapsed_micros;
+ for (auto& no : ns.output()) {
+ int si = no.slot();
+ if (static_cast<size_t>(si) >= slot_bytes_[global_id].size()) {
+ slot_bytes_[global_id].resize(1 + si);
+ }
+ slot_bytes_[global_id][si] +=
+ no.tensor_description().allocation_description().requested_bytes();
+ }
+ }
+ }
+}
+
+void CostModel::Ensure(int id) {
+ if (slot_bytes_.size() <= static_cast<size_t>(id)) {
+ slot_bytes_.resize(id + 1);
+ count_.resize(id + 1);
+ time_.resize(id + 1);
+ }
+}
+
+void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
+ const int id = Id(node);
+ if (id < 0) return;
+ Ensure(id);
+ auto perslot = &slot_bytes_[id];
+ if (perslot->size() > 0) {
+ CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
+ << node->name();
+ } else {
+ perslot->resize(num_outputs, Bytes(-1));
+ }
+}
+
+void CostModel::RecordCount(const Node* node, int count) {
+ const int id = Id(node);
+ if (id < 0) return;
+ CHECK_LT(id, slot_bytes_.size());
+ count_[id] += count;
+}
+
+int32 CostModel::TotalCount(const Node* node) const {
+ const int id = Id(node);
+ if (id < 0) return 0;
+ return (static_cast<size_t>(id) < slot_bytes_.size()) ? count_[id] : 0;
+}
+
+void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) {
+ const int id = Id(node);
+ if (id < 0) return;
+ CHECK_LT(id, slot_bytes_.size());
+ auto perslot = &slot_bytes_[id];
+ CHECK_LT(slot, perslot->size());
+ auto v = &(*perslot)[slot];
+ if (*v >= 0) {
+ *v += bytes;
+ } else {
+ *v = bytes;
+ }
+}
+
+Bytes CostModel::TotalBytes(const Node* node, int slot) const {
+ const int id = Id(node);
+ if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
+ slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
+ return Bytes(0);
+ }
+ return slot_bytes_[id][slot];
+}
+
+Bytes CostModel::SizeEstimate(const Node* node, int slot) const {
+ int32 count = TotalCount(node);
+ if (count < min_count_) return Bytes(0);
+ return TotalBytes(node, slot) / std::max(1, TotalCount(node));
+}
+
+void CostModel::RecordTime(const Node* node, Microseconds time) {
+ const int id = Id(node);
+ if (id < 0) return;
+ DCHECK(node->IsOp()) << node->DebugString();
+ Ensure(id);
+ time_[id] += time;
+}
+
+Microseconds CostModel::TotalTime(const Node* node) const {
+ DCHECK(node->IsOp()) << node->DebugString();
+ const int id = Id(node);
+ if (id < 0 || static_cast<size_t>(id) >= time_.size() ||
+ time_[id] < Microseconds(0)) {
+ return Microseconds(0);
+ }
+ return time_[id];
+}
+
+Microseconds CostModel::TimeEstimate(const Node* node) const {
+ int32 count = TotalCount(node);
+ if (count <= min_count_) return kMinTimeEstimate;
+ return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count));
+}
+
+void CostModel::CheckInitialized(const Graph& graph) const {
+ for (const Node* n : graph.nodes()) {
+ if (n->IsOp()) {
+ CHECK(static_cast<size_t>(n->id()) < time_.size() &&
+ time_[n->id()] >= Microseconds(0))
+ << ": no time estimate for " << n->DebugString();
+
+ CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
+ << ": no size estimate for " << n->DebugString();
+ const auto& perslot = slot_bytes_[n->id()];
+ for (size_t i = 0; i < perslot.size(); i++) {
+ CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
+ << " of " << n->DebugString();
+ }
+ }
+ }
+}
+
+Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis,
+ double estimated_gbps) {
+ // TODO(jeff,sanjay): estimate cost based on bandwidth along the
+ // communication path and the type of transport we are using between
+ // devices.
+ //
+ // We assume the copy time follows a linear model:
+ // copy_time = copy_bytes / rate + min_time
+ int64 copy_bytes = b.value();
+ const double bytes_per_usec = estimated_gbps * 1000.0 / 8;
+ const double min_micros = network_latency_millis * 1000.0;
+ return Microseconds(
+ static_cast<int64>(copy_bytes / bytes_per_usec + min_micros));
+}
+
+Microseconds CostModel::ComputationTimeEstimate(int64 math_ops) {
+ // TODO(jeff,sanjay): Eventually we should pass in the type of device
+ // (GPU vs. CPU) and use that to affect the estimate.
+
+ // We estimate the microseconds using that value. We divide
+ // by 1000 to convert the madd number into microseconds (assuming
+ // roughly 1000 madds per microsecond (~1 GHz for one core)).
+ return Microseconds(math_ops / 1000);
+}
+
+// ----------------------------------------------------------------------------
+// InitCostModel
+// ----------------------------------------------------------------------------
+
+namespace {
+
+static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) {
+ for (Node* n : g.nodes()) {
+ const int num_outputs = n->num_outputs();
+ cost_model->SetNumOutputs(n, num_outputs);
+ for (int output = 0; output < num_outputs; output++) {
+ // Set up an initial bogus estimate for the node's outputs
+ cost_model->RecordSize(n, output, Bytes(1));
+ }
+ }
+}
+
+static void AssignSizes(const Graph& g, CostModel* cost_model) {
+ for (const Edge* e : g.edges()) {
+ // Skip if it is a control edge.
+ if (e->IsControlEdge()) {
+ continue;
+ }
+ Node* src = e->src();
+
+ // TODO(josh11b): Get an estimate from the Op
+ Bytes size(1);
+ cost_model->RecordSize(src, e->src_output(), size);
+ }
+}
+
+// This generates an extremely simple initial guess for the
+// computation cost of each node. For ordinary Ops, its value should quickly
+// be wiped out by the real runtime measurements. For other Ops we don't
+// actually generate measurements, so suppression of infrequent Ops ends up
+// giving them 0 costs. So, this is not of much consequence except perhaps
+// in tests.
+static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) {
+ CHECK(n->IsOp());
+ VLOG(2) << "Node " << n->id() << ": " << n->name()
+ << " type_string: " << n->type_string();
+ if (IsConstant(n) || IsVariable(n)) {
+ return Microseconds(0);
+ }
+ return kDefaultTimeEstimate;
+}
+
+static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) {
+ for (Node* n : g.nodes()) {
+ if (!n->IsOp()) continue;
+ cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n));
+ }
+}
+
+} // namespace
+
+void CostModel::InitFromGraph(const Graph& g) {
+ AddNodesToCostModel(g, this);
+ AssignSizes(g, this);
+ EstimateComputationCosts(g, this);
+ CheckInitialized(g);
+}
+
+void CostModel::WriteToLog() {
+ LOG(INFO) << " min_count_=" << min_count_;
+ for (size_t i = 0; i < count_.size(); ++i) {
+ LOG(INFO) << "Node " << i << " count " << count_[i] << " total time "
+ << time_[i] << " avg time "
+ << (time_[i] / (std::max(1, count_[i])));
+ }
+}
+
+} // namespace tensorflow
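
To make the two static heuristics concrete, a worked sketch with illustrative numbers (1 MB over a 1 Gbps link with 1 ms latency, and 2,000,000 multiply-adds):

#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void CostEstimateSketch() {
  // bytes/usec = 1 Gbps * 1000 / 8 = 125, so
  // copy time  = 1,000,000 / 125 + 1000 = 9000 usec.
  Microseconds copy = CostModel::CopyTimeEstimate(
      Bytes(1000000), 1.0 /* network_latency_millis */, 1.0 /* gbps */);
  CHECK_EQ(9000, copy.value());

  // 2,000,000 multiply-adds at ~1000 madds per usec -> 2000 usec.
  Microseconds compute = CostModel::ComputationTimeEstimate(2000000);
  CHECK_EQ(2000, compute.value());
}

}  // namespace tensorflow
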
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
new file mode 100644
index 0000000000..4d7dd65f5a
--- /dev/null
+++ b/tensorflow/core/graph/costmodel.h
@@ -0,0 +1,123 @@
+#ifndef TENSORFLOW_GRAPH_COSTMODEL_H_
+#define TENSORFLOW_GRAPH_COSTMODEL_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+typedef std::unordered_map<string, int32> NodeNameToCostIdMap;
+
+class StepStats;
+
+// CostModel keeps track of the following runtime statistics for nodes
+// of a single Graph:
+// * The total number of times a node has executed.
+// * The accumulated execution time (in microseconds) of a node.
+// * The accumulated size (in bytes) of each node's output.
+//
+// This class is NOT thread-safe.
+class CostModel {
+ public:
+ // If "global" is true, maintains costs based on Node::cost_id, otherwise
+ // maintains costs based on Node::id.
+ explicit CostModel(bool is_global) : is_global_(is_global) {}
+
+ // Assigns min_count_ as a function of the median count for a Node.
+ // This value is then used for suppressing the time/size costs of
+ // infrequent operations.
+ // NOTE(tucker): Maybe this should move to a subclass of CostModel.
+ void SuppressInfrequent();
+
+ bool is_global() const { return is_global_; }
+
+ // Initializes cost model for 'g'.
+ void InitFromGraph(const Graph& g);
+
+ // Merges costs from cm.
+ // REQUIRES: is_global_ is true for this and for "cm"
+ void MergeFromGlobal(const CostModel& cm);
+
+ // Merges costs from "cm", which has been computed relative to "g".
+ // REQUIRES: is_global_ is true for this, and false for "cm".
+ void MergeFromLocal(const Graph& g, const CostModel& cm);
+
+ void MergeFromStats(const NodeNameToCostIdMap& map, const StepStats& ss);
+
+ // Sets the number of outputs of "node".
+ void SetNumOutputs(const Node* node, int num_outputs);
+
+ // Records that "node" has executed "num_count" more times.
+ void RecordCount(const Node* node, int num_count);
+
+ // Returns how many times "node" has been executed.
+ int32 TotalCount(const Node* node) const;
+
+ // Records that "output_slot" of "node" has produced tensors of
+ // aggregated "bytes".
+ void RecordSize(const Node* node, int output_slot, Bytes bytes);
+
+ // Returns the total bytes of tensors produced by "node"'s "output_slot".
+ Bytes TotalBytes(const Node* node, int output_slot) const;
+
+ // Returns a prediction for the size of the tensor at the
+ // output_slot produced by one execution of "node".
+ Bytes SizeEstimate(const Node* node, int output_slot) const;
+
+ // Records that executions of "node" have taken "time" microseconds.
+ void RecordTime(const Node* node, Microseconds time);
+
+ // Returns the total execution time for "node".
+ Microseconds TotalTime(const Node* node) const;
+
+ // Returns a prediction of the time for one execution of "node".
+ Microseconds TimeEstimate(const Node* node) const;
+
+ // Checks that an estimate is available for every Op node in the graph.
+ void CheckInitialized(const Graph& graph) const;
+
+ // Helper routines to encapsulate static estimation heuristics.
+
+ // Compute an estimate of the time to copy "b" bytes over the network,
+ // given a fixed cost of "network_latency_millis" milliseconds and
+ // an estimated bandwidth of "estimated_gbps" gigabits per second (note that
+ // this value is in gigabits, not gigabytes).
+ static Microseconds CopyTimeEstimate(Bytes b, double network_latency_millis,
+ double estimated_gbps);
+ static Microseconds ComputationTimeEstimate(int64 mathops);
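+ //
+ // For illustration only, assuming the latency-plus-bytes/bandwidth model
+ // described above: CopyTimeEstimate(Bytes(1 << 20), 1.0, 1.0) is roughly
+ // 1000us of latency plus (2^20 * 8) bits / 1e9 bits/s ~= 8389us on the
+ // wire, i.e. about 9389us in total.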
+
+ // Write the contents of the CostModel to the INFO log.
+ void WriteToLog();
+
+ private:
+ const bool is_global_;
+ inline int Id(const Node* n) const {
+ if (is_global_) {
+ return n->cost_id();
+ } else {
+ return n->id();
+ }
+ }
+ // Resizes vectors so that they are large enough for "id".
+ void Ensure(int id);
+
+ // Nodes and Edges whose count is < this value
+ // get time/byte estimates of 0.
+ int32 min_count_ = 0;
+
+ // Number of times each Node has been executed.
+ std::vector<int32> count_;
+ // Cumulative execution time.
+ std::vector<Microseconds> time_;
+ // Cumulative Bytes output on each channel.
+ std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CostModel);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_COSTMODEL_H_
diff --git a/tensorflow/core/graph/costutil.cc b/tensorflow/core/graph/costutil.cc
new file mode 100644
index 0000000000..f8e2d9fe68
--- /dev/null
+++ b/tensorflow/core/graph/costutil.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/graph/costutil.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/costmodel.h"
+
+namespace tensorflow {
+
+std::vector<int64> LongestOutgoingPathCost(const Graph& graph,
+ const CostModel& cm) {
+ std::vector<int64> result(graph.num_node_ids());
+ DFS(graph, nullptr, [&result, &cm](Node* n) {
+ int64 max_child = 0;
+ for (const Node* out : n->out_nodes()) {
+ max_child = std::max(max_child, result[out->id()]);
+ }
+ result[n->id()] = max_child + (n->IsOp() ? cm.TimeEstimate(n).value() : 0);
+ });
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/costutil.h b/tensorflow/core/graph/costutil.h
new file mode 100644
index 0000000000..46e5215132
--- /dev/null
+++ b/tensorflow/core/graph/costutil.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_GRAPH_COSTUTIL_H_
+#define TENSORFLOW_GRAPH_COSTUTIL_H_
+
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class CostModel;
+class Graph;
+
+// result[i] is an estimate of the longest execution path from
+// the node with id i to the sink node.
+std::vector<int64> LongestOutgoingPathCost(const Graph& graph,
+ const CostModel& cm);
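+//
+// A minimal usage sketch; the Graph "g" and CostModel "cm" are illustrative:
+//
+//   std::vector<int64> cost = LongestOutgoingPathCost(g, cm);
+//   // cost[n->id()] is the estimated longest outgoing path from node n,
+//   // in microseconds of estimated execution time.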
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_COSTUTIL_H_
diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h
new file mode 100644
index 0000000000..30cd4e8a57
--- /dev/null
+++ b/tensorflow/core/graph/default_device.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+#define TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace tensorflow {
+namespace graph {
+
+// Sets the default device for all nodes in graph_def to "device",
+// but only for nodes whose device is not already set.
+inline void SetDefaultDevice(const string& device, GraphDef* graph_def) {
+ for (int i = 0; i < graph_def->node_size(); ++i) {
+ auto node = graph_def->mutable_node(i);
+ if (node->device().empty()) {
+ node->set_device(device);
+ }
+ }
+}
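+
+// A minimal usage sketch; the GraphDef "gdef" and the device name are
+// illustrative only:
+//
+//   graph::SetDefaultDevice("/cpu:0", &gdef);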
+
+} // namespace graph
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
diff --git a/tensorflow/core/graph/dot.cc b/tensorflow/core/graph/dot.cc
new file mode 100644
index 0000000000..6d6e46ce61
--- /dev/null
+++ b/tensorflow/core/graph/dot.cc
@@ -0,0 +1,289 @@
+#include "tensorflow/core/graph/dot.h"
+
+#include <map>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/graph/colors.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+static string GraphNodeName(const DotOptions& opts, const Node* n) {
+ return strings::StrCat("N", n->id());
+}
+
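+// Returns true if the op type adds information to the node label, i.e. when
+// the op is not a NoOp and the node name does not already start with
+// "<op type>_".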
+bool ShouldDisplayOpType(const Node* n) {
+ if (n->type_string() == "NoOp") {
+ return false;
+ }
+ const string& op_name = n->def().name();
+ if (op_name.find(n->type_string() + "_") == 0) {
+ return false;
+ }
+ return true;
+}
+
+string DotGraph(const Graph& g, const DotOptions& opts) {
+ RegexpStringPiece flag(opts.prefix_collapse_regexp);
+ if (flag == "all") {
+ flag = ".";
+ } else if (flag == "none") {
+ flag = "^$";
+ }
+ RE2 cluster_name_pattern(flag);
+ string result;
+ strings::StrAppend(&result, "digraph G {\n");
+ strings::StrAppend(&result, "rankdir=\"BT\"\n");
+
+ std::map<string, int> device_index; // Map from device name to index.
+ std::unordered_set<Node*> visible_nodes; // Nodes to display.
+ // Cluster name => set of nodes.
+ std::unordered_map<string, std::unordered_set<Node*> > clusters;
+ // Node* => Cluster
+ std::unordered_map<Node*, string> node_cluster;
+ for (Node* src : g.nodes()) {
+ if (opts.include_node_function != nullptr &&
+ !opts.include_node_function(src)) {
+ continue;
+ }
+ // Do not display source and sink nodes
+ if (src->IsSource() || src->IsSink()) {
+ continue;
+ }
+ visible_nodes.insert(src);
+ const string name_prefix = NodeNamePrefix(src->def().name()).ToString();
+ if (!name_prefix.empty()) {
+ clusters[name_prefix].insert(src);
+ node_cluster[src] = name_prefix;
+ }
+ // Record device if present.
+ if (src->IsOp()) {
+ const string& d = src->assigned_device_name();
+ if (!d.empty()) {
+ device_index[d] = -1; // Assigned later
+ }
+ }
+ }
+
+ // Add nodes whose name is exactly a cluster name to the cluster itself.
+ for (Node* src : g.nodes()) {
+ if (node_cluster.count(src) == 0) {
+ const string name = src->def().name();
+ auto it = clusters.find(name);
+ if (it != clusters.end()) {
+ it->second.insert(src);
+ node_cluster[src] = name;
+ }
+ }
+ }
+
+ auto node_in_collapsed_cluster = [&node_cluster,
+ &cluster_name_pattern](Node* n) {
+ return node_cluster.count(n) > 0 &&
+ RE2::PartialMatch(node_cluster[n], cluster_name_pattern);
+ };
+
+ // Assign device indices in sorted order.
+ int num = 0;
+ for (auto& e : device_index) {
+ e.second = num++;
+ }
+
+ double total_node_cost = 0;
+ double avg_node_cost = 1;
+ if (opts.node_cost) {
+ int node_count = 0;
+ for (const Node* n : g.nodes()) {
+ total_node_cost += opts.node_cost(n);
+ ++node_count;
+ }
+ if (total_node_cost > 0) avg_node_cost = total_node_cost / node_count;
+ }
+
+ for (Node* src : g.nodes()) {
+ if (visible_nodes.count(src) == 0 || node_in_collapsed_cluster(src)) {
+ continue;
+ }
+ string label = src->name();
+ if (ShouldDisplayOpType(src)) {
+ // Append the op type if it is not directly deducible from the node name.
+ strings::StrAppend(&label, "\\n(", src->type_string(), ")");
+ }
+ const char* shape = "box";
+ const char* color = nullptr;
+ if (src->IsSource()) {
+ shape = "oval";
+ } else if (src->IsSink()) {
+ shape = "oval";
+ } else {
+ const string& d = src->assigned_device_name();
+ const int dindex = (!d.empty()) ? device_index[d] : -1;
+ if (dindex >= 0) {
+ color = ColorFor(dindex);
+ }
+
+ shape = "box";
+ }
+
+ if (opts.node_label) {
+ string extra = opts.node_label(src);
+ if (!extra.empty()) {
+ strings::StrAppend(&label, "\\n", extra);
+ }
+ }
+
+ strings::StrAppend(&result, GraphNodeName(opts, src), "[shape=", shape,
+ ", label=\"", label, "\"");
+ if (opts.node_cost && total_node_cost > 0) {
+ // Pick fontsize in range [8..40] so that area is proportional to cost.
+ const double cost = opts.node_cost(src);
+ const double relcost = fabs(cost / avg_node_cost);
+ // Average cost node has font size of 12.
+ const int fs = 8 + static_cast<int>(4.0 * std::min(sqrt(relcost), 8.0));
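+ // e.g. relcost 1 -> fontsize 12, relcost 4 -> 16, relcost >= 64 -> 40.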
+ strings::StrAppend(&result, ", width=0, height=0, fontsize=", fs);
+ VLOG(2) << "Node: " << cost << " => " << relcost << " => " << fs;
+ }
+ if (color != nullptr) {
+ strings::StrAppend(&result, ", fillcolor=\"", color,
+ "\", fontcolor=\"white\", style=\"filled\"");
+ }
+ strings::StrAppend(&result, "]\n");
+ }
+
+ for (auto c : clusters) {
+ const string& cluster_name = c.first;
+ const std::unordered_set<Node*> nodes = c.second;
+ std::unordered_map<string, int> node_colors;
+ for (auto n : nodes) {
+ const string& d = n->assigned_device_name();
+ const int dindex = (!d.empty()) ? device_index[d] : -1;
+ if (dindex >= 0) {
+ ++node_colors[ColorFor(dindex)];
+ }
+ }
+
+ string majority_color;
+ if (node_colors.empty()) {
+ majority_color = ColorFor(0);
+ } else {
+ majority_color = std::max_element(node_colors.begin(), node_colors.end(),
+ [](const std::pair<string, int>& x,
+ const std::pair<string, int>& y) {
+ return x.second < y.second;
+ })
+ ->first;
+ }
+
+ if (!RE2::PartialMatch(cluster_name, cluster_name_pattern)) {
+ strings::StrAppend(&result, "subgraph cluster_", cluster_name, "{\n");
+ for (auto n : nodes) {
+ strings::StrAppend(&result, GraphNodeName(opts, n), ";\n");
+ }
+ strings::StrAppend(&result, "}\n");
+ } else {
+ strings::StrAppend(&result, cluster_name, " [shape=oval, fillcolor=\"",
+ majority_color, "\", label=\"", cluster_name,
+ "\", style=\"filled\", fontcolor=\"white\"]\n");
+ }
+ }
+
+ std::unordered_set<string> edge_drawn;
+
+ double max_edge_cost = 0;
+ double total_edge_cost = 0;
+ double avg_edge_cost = 1;
+ if (opts.edge_cost && g.edges().size()) {
+ for (const Edge* e : g.edges()) {
+ auto cost = opts.edge_cost(e);
+ total_edge_cost += cost;
+ max_edge_cost = std::max(max_edge_cost, cost);
+ }
+ avg_edge_cost = total_edge_cost / g.edges().size();
+ }
+ VLOG(2) << "Edge cost tot/max/avg: " << total_edge_cost << "/"
+ << max_edge_cost << "/" << avg_edge_cost;
+
+ for (const Edge* e : g.edges()) {
+ Node* src = e->src();
+ Node* dst = e->dst();
+ // If either endpoint isn't drawn in the graph, don't draw the edge
+ if (visible_nodes.count(src) == 0 || visible_nodes.count(dst) == 0) {
+ continue;
+ }
+
+ const string src_name = node_in_collapsed_cluster(src)
+ ? node_cluster[src]
+ : GraphNodeName(opts, src);
+ const string dst_name = node_in_collapsed_cluster(dst)
+ ? node_cluster[dst]
+ : GraphNodeName(opts, dst);
+ // Don't draw self edges
+ if (src_name == dst_name) {
+ continue;
+ }
+ // And previously drawn edges.
+ const string& edge_name = strings::StrCat(src_name, ":", dst_name);
+ if (edge_drawn.count(edge_name) > 0) {
+ continue;
+ }
+ edge_drawn.insert(edge_name);
+
+ strings::StrAppend(&result, src_name, " -> ", dst_name, "[");
+ if (e->IsControlEdge()) {
+ strings::StrAppend(&result, " style=dotted");
+ }
+ if (opts.edge_label) {
+ const string label = opts.edge_label(e);
+ if (!label.empty()) {
+ strings::StrAppend(&result, " label=<", label, ">");
+ }
+ }
+ // Make edge widths proportional to amount of data transferred.
+ if (opts.edge_cost && max_edge_cost > 0) {
+ const double cost = opts.edge_cost(e);
+ const double relcost = fabs(cost / avg_edge_cost);
+ // Pick penwidth in range [1..6] so that width is proportional to cost.
+ const int pw = 1 + std::min(5, static_cast<int>(2.0 * relcost));
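+ // e.g. relcost 1 -> penwidth 3; relcost >= 2.5 saturates at penwidth 6.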
+ strings::StrAppend(&result, " penwidth=", pw);
+ // Use weight attributes [1..100] to keep heavier edges more vertical.
+ const int weight = 1 + std::min(99, static_cast<int>(100.0 * relcost));
+ strings::StrAppend(&result, " weight=", weight);
+ VLOG(2) << "Edge: " << cost << " => " << relcost << " => " << pw << "/"
+ << weight;
+ }
+
+ strings::StrAppend(&result, "]\n");
+ }
+ // Compute some statistics
+ int op_nodes = 0;
+ for (Node* n : g.nodes()) {
+ if (n->IsOp()) {
+ op_nodes++;
+ }
+ }
+
+ // Emit legend
+ strings::StrAppend(&result,
+ "{ rank = source; Legend [shape=box, margin=0, label=<",
+ "<TABLE BORDER=\"0\" CELLBORDER=\"1\" CELLSPACING=\"0\" ",
+ "CELLPADDING=\"4\">", "<TR><TD COLSPAN=\"2\">op_nodes: ",
+ op_nodes, "</TD></TR>\n");
+ for (const auto& e : device_index) {
+ const int dindex = e.second;
+ strings::StrAppend(&result, "<TR><TD BGCOLOR=\"", ColorFor(dindex),
+ "\"><FONT COLOR=\"white\">", dindex, "</FONT></TD><TD>",
+ e.first, "</TD></TR>\n");
+ }
+ strings::StrAppend(&result, "</TABLE>>]}\n");
+
+ strings::StrAppend(&result, "}\n"); // End digraph
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/dot.h b/tensorflow/core/graph/dot.h
new file mode 100644
index 0000000000..f87f68099c
--- /dev/null
+++ b/tensorflow/core/graph/dot.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_GRAPH_DOT_H_
+#define TENSORFLOW_GRAPH_DOT_H_
+
+#include <functional>
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Edge;
+class Graph;
+class Node;
+
+struct DotOptions {
+ bool (*include_node_function)(const Node*) = nullptr;
+
+ // By default, all nodes with the same name prefix are collapsed into
+ // a single node in the dot graph. This regexp can be changed so that
+ // only prefixes that match the regexp are collapsed in this fashion.
+ // 'all' collapses all ops with prefixes, 'none' disables all collapsing.
+ string prefix_collapse_regexp = "all";
+
+ // A function that returns a label to embed into the per-node display.
+ std::function<string(const Node*)> node_label;
+
+ // A function that returns a label to attach to an edge.
+ std::function<string(const Edge*)> edge_label;
+
+ // A function that returns the "cost" of the node. The dot display
+ // makes a node's size proportional to its cost.
+ std::function<double(const Node*)> node_cost;
+
+ // A function that returns the "cost" of the edge. The dot display
+ // makes an edge's thickness proportional to its cost.
+ std::function<double(const Edge*)> edge_cost;
+};
+
+// Return a string that contains a graphviz specification of the graph.
+string DotGraph(const Graph& g, const DotOptions& opts);
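+//
+// A minimal usage sketch; the Graph "g" is illustrative and the callbacks
+// are optional:
+//
+//   DotOptions opts;
+//   opts.prefix_collapse_regexp = "none";  // draw every node individually
+//   opts.node_label = [](const Node* n) { return n->assigned_device_name(); };
+//   string dot = DotGraph(g, opts);        // render with graphviz "dot"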
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_DOT_H_
diff --git a/tensorflow/core/graph/edgeset.cc b/tensorflow/core/graph/edgeset.cc
new file mode 100644
index 0000000000..83293c7b4e
--- /dev/null
+++ b/tensorflow/core/graph/edgeset.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/core/graph/edgeset.h"
+
+namespace tensorflow {
+
+std::pair<EdgeSet::const_iterator, bool> EdgeSet::insert(value_type value) {
+ RegisterMutation();
+ const_iterator ci;
+ ci.Init(this);
+ auto s = get_set();
+ if (!s) {
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i] == value) {
+ ci.array_iter_ = &ptrs_[i];
+ return std::make_pair(ci, false);
+ }
+ }
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i] == nullptr) {
+ ptrs_[i] = value;
+ ci.array_iter_ = &ptrs_[i];
+ return std::make_pair(ci, true);
+ }
+ }
+ // array is full. convert to set.
+ s = new std::set<const Edge*>;
+ for (int i = 0; i < kInline; i++) {
+ s->insert(static_cast<const Edge*>(ptrs_[i]));
+ }
+ ptrs_[0] = this;
+ ptrs_[1] = s;
+ // fall through.
+ }
+ auto p = s->insert(value);
+ ci.tree_iter_ = p.first;
+ return std::make_pair(ci, p.second);
+}
+
+EdgeSet::size_type EdgeSet::erase(key_type key) {
+ RegisterMutation();
+ auto s = get_set();
+ if (!s) {
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i] == key) {
+ size_t n = size();
+ ptrs_[i] = ptrs_[n - 1];
+ ptrs_[n - 1] = nullptr;
+ return 1;
+ }
+ }
+ return 0;
+ } else {
+ return s->erase(key);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h
new file mode 100644
index 0000000000..df0d78b8fb
--- /dev/null
+++ b/tensorflow/core/graph/edgeset.h
@@ -0,0 +1,216 @@
+#ifndef TENSORFLOW_GRAPH_EDGESET_H_
+#define TENSORFLOW_GRAPH_EDGESET_H_
+
+#include <stddef.h>
+#include <set>
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+class Edge;
+
+// An unordered set of edges. Uses very little memory for small sets.
+// Unlike std::set, EdgeSet does NOT allow mutations during iteration.
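+//
+// A minimal usage sketch; the Edge pointers e0 and e1 are illustrative:
+//
+//   EdgeSet edges;
+//   edges.insert(e0);
+//   edges.insert(e1);
+//   for (const Edge* e : edges) { ... }  // do not insert/erase while iterating
+//   edges.erase(e0);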
+class EdgeSet {
+ public:
+ EdgeSet();
+ ~EdgeSet();
+
+ typedef const Edge* key_type;
+ typedef const Edge* value_type;
+ typedef size_t size_type;
+ typedef ptrdiff_t difference_type;
+
+ class const_iterator;
+ typedef const_iterator iterator;
+
+ bool empty() const;
+ size_type size() const;
+ void clear();
+ std::pair<iterator, bool> insert(value_type value);
+ size_type erase(key_type key);
+
+ // Caller is not allowed to mutate the EdgeSet while iterating.
+ const_iterator begin() const;
+ const_iterator end() const;
+
+ private:
+ // Up to kInline elements are stored directly in ptrs_ (nullptr means none).
+ // If ptrs_[0] == this then ptrs_[1] points to a set<const Edge*>.
+ static const int kInline = 2; // Must be >= 2.
+ const void* ptrs_[kInline];
+
+ std::set<const Edge*>* get_set() const {
+ if (ptrs_[0] == this) {
+ return static_cast<std::set<const Edge*>*>(const_cast<void*>(ptrs_[1]));
+ } else {
+ return nullptr;
+ }
+ }
+
+// To detect mutations while iterating.
+#ifdef NDEBUG
+ void RegisterMutation() {}
+#else
+ uint32 mutations_ = 0;
+ void RegisterMutation() { mutations_++; }
+#endif
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EdgeSet);
+};
+
+class EdgeSet::const_iterator {
+ public:
+ typedef typename EdgeSet::value_type value_type;
+ typedef const typename EdgeSet::value_type& reference;
+ typedef const typename EdgeSet::value_type* pointer;
+ typedef typename EdgeSet::difference_type difference_type;
+ typedef std::forward_iterator_tag iterator_category;
+
+ const_iterator() {}
+
+ const_iterator& operator++();
+ const_iterator operator++(int /*unused*/);
+ const value_type* operator->() const;
+ value_type operator*() const;
+ bool operator==(const const_iterator& other) const;
+ bool operator!=(const const_iterator& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ friend class EdgeSet;
+
+ void const* const* array_iter_ = nullptr;
+ typename std::set<const Edge*>::const_iterator tree_iter_;
+
+#ifdef NDEBUG
+ inline void Init(const EdgeSet* e) {}
+ inline void CheckNoMutations() const {}
+#else
+ inline void Init(const EdgeSet* e) {
+ owner_ = e;
+ init_mutations_ = e->mutations_;
+ }
+ inline void CheckNoMutations() const {
+ CHECK_EQ(init_mutations_, owner_->mutations_);
+ }
+ const EdgeSet* owner_ = nullptr;
+ uint32 init_mutations_ = 0;
+#endif
+};
+
+inline EdgeSet::EdgeSet() {
+ for (int i = 0; i < kInline; i++) {
+ ptrs_[i] = nullptr;
+ }
+}
+
+inline EdgeSet::~EdgeSet() { delete get_set(); }
+
+inline bool EdgeSet::empty() const { return size() == 0; }
+
+inline EdgeSet::size_type EdgeSet::size() const {
+ auto s = get_set();
+ if (s) {
+ return s->size();
+ } else {
+ size_t result = 0;
+ for (int i = 0; i < kInline; i++) {
+ if (ptrs_[i]) result++;
+ }
+ return result;
+ }
+}
+
+inline void EdgeSet::clear() {
+ RegisterMutation();
+ delete get_set();
+ for (int i = 0; i < kInline; i++) {
+ ptrs_[i] = nullptr;
+ }
+}
+
+inline EdgeSet::const_iterator EdgeSet::begin() const {
+ const_iterator ci;
+ ci.Init(this);
+ auto s = get_set();
+ if (s) {
+ ci.tree_iter_ = s->begin();
+ } else {
+ ci.array_iter_ = &ptrs_[0];
+ }
+ return ci;
+}
+
+inline EdgeSet::const_iterator EdgeSet::end() const {
+ const_iterator ci;
+ ci.Init(this);
+ auto s = get_set();
+ if (s) {
+ ci.tree_iter_ = s->end();
+ } else {
+ ci.array_iter_ = &ptrs_[size()];
+ }
+ return ci;
+}
+
+inline EdgeSet::const_iterator& EdgeSet::const_iterator::operator++() {
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ ++array_iter_;
+ } else {
+ ++tree_iter_;
+ }
+ return *this;
+}
+
+inline EdgeSet::const_iterator EdgeSet::const_iterator::operator++(
+ int /*unused*/) {
+ CheckNoMutations();
+ const_iterator tmp = *this;
+ operator++();
+ return tmp;
+}
+
+// gcc's set and multiset always use const_iterator since it will otherwise
+// allow modification of keys.
+inline const EdgeSet::const_iterator::value_type* EdgeSet::const_iterator::
+operator->() const {
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ return reinterpret_cast<const value_type*>(array_iter_);
+ } else {
+ return tree_iter_.operator->();
+ }
+}
+
+// gcc's set and multiset always use const_iterator since it will otherwise
+// allow modification of keys.
+inline EdgeSet::const_iterator::value_type EdgeSet::const_iterator::operator*()
+ const {
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ return static_cast<value_type>(*array_iter_);
+ } else {
+ return *tree_iter_;
+ }
+}
+
+inline bool EdgeSet::const_iterator::operator==(
+ const const_iterator& other) const {
+ DCHECK((array_iter_ == nullptr) == (other.array_iter_ == nullptr))
+ << "Iterators being compared must be from same set that has not "
+ << "been modified since the iterator was constructed";
+ CheckNoMutations();
+ if (array_iter_ != nullptr) {
+ return array_iter_ == other.array_iter_;
+ } else {
+ return other.array_iter_ == nullptr && tree_iter_ == other.tree_iter_;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_EDGESET_H_
diff --git a/tensorflow/core/graph/edgeset_test.cc b/tensorflow/core/graph/edgeset_test.cc
new file mode 100644
index 0000000000..7909e8ea0a
--- /dev/null
+++ b/tensorflow/core/graph/edgeset_test.cc
@@ -0,0 +1,95 @@
+#include "tensorflow/core/graph/edgeset.h"
+
+#include "tensorflow/core/graph/graph.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+class EdgeSetTest : public ::testing::Test {
+ public:
+ EdgeSetTest() : edges_(nullptr), eset_(nullptr) {}
+
+ ~EdgeSetTest() override {
+ delete eset_;
+ delete[] edges_;
+ }
+
+ void MakeEdgeSet(int n) {
+ delete eset_;
+ delete[] edges_;
+ edges_ = new Edge[n];
+ eset_ = new EdgeSet;
+ model_.clear();
+ for (int i = 0; i < n; i++) {
+ eset_->insert(&edges_[i]);
+ model_.insert(&edges_[i]);
+ }
+ }
+
+ void CheckSame() {
+ EXPECT_EQ(model_.size(), eset_->size());
+ EXPECT_EQ(model_.empty(), eset_->empty());
+ std::vector<const Edge*> modelv(model_.begin(), model_.end());
+ std::vector<const Edge*> esetv(eset_->begin(), eset_->end());
+ std::sort(modelv.begin(), modelv.end());
+ std::sort(esetv.begin(), esetv.end());
+ EXPECT_EQ(modelv.size(), esetv.size());
+ for (size_t i = 0; i < modelv.size(); i++) {
+ EXPECT_EQ(modelv[i], esetv[i]) << i;
+ }
+ }
+
+ Edge nonexistent_;
+ Edge* edges_;
+ EdgeSet* eset_;
+ std::set<const Edge*> model_;
+};
+
+namespace {
+
+TEST_F(EdgeSetTest, Ops) {
+ for (int n : {0, 1, 2, 3, 4, 10}) {
+ MakeEdgeSet(n);
+ CheckSame();
+ EXPECT_EQ((n == 0), eset_->empty());
+ EXPECT_EQ(n, eset_->size());
+
+ eset_->clear();
+ model_.clear();
+ CheckSame();
+
+ eset_->insert(&edges_[0]);
+ model_.insert(&edges_[0]);
+ CheckSame();
+ }
+}
+
+// Try insert/erase of existing elements at different positions.
+TEST_F(EdgeSetTest, Exists) {
+ for (int n : {0, 1, 2, 3, 4, 10}) {
+ MakeEdgeSet(n);
+ for (int pos = 0; pos < n; pos++) {
+ MakeEdgeSet(n);
+ auto p = eset_->insert(&edges_[pos]);
+ EXPECT_FALSE(p.second);
+ EXPECT_EQ(&edges_[pos], *p.first);
+
+ EXPECT_EQ(1, eset_->erase(&edges_[pos]));
+ model_.erase(&edges_[pos]);
+ CheckSame();
+ }
+ }
+}
+
+// Try insert/erase of non-existent element.
+TEST_F(EdgeSetTest, DoesNotExist) {
+ for (int n : {0, 1, 2, 3, 4, 10}) {
+ MakeEdgeSet(n);
+ EXPECT_EQ(0, eset_->erase(&nonexistent_));
+ auto p = eset_->insert(&nonexistent_);
+ EXPECT_TRUE(p.second);
+ EXPECT_EQ(&nonexistent_, *p.first);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/equal_graph_def.cc b/tensorflow/core/graph/equal_graph_def.cc
new file mode 100644
index 0000000000..35f59b5ed0
--- /dev/null
+++ b/tensorflow/core/graph/equal_graph_def.cc
@@ -0,0 +1,176 @@
+#include "tensorflow/core/graph/equal_graph_def.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
+ string* diff) {
+ std::unordered_map<string, const NodeDef*> actual_index;
+ for (const NodeDef& node : actual.node()) {
+ actual_index[node.name()] = &node;
+ }
+
+ for (const NodeDef& expected_node : expected.node()) {
+ auto actual_iter = actual_index.find(expected_node.name());
+ if (actual_iter == actual_index.end()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Did not find expected node '",
+ SummarizeNodeDef(expected_node), "'");
+ }
+ return false;
+ }
+
+ if (!EqualNodeDef(*actual_iter->second, expected_node, diff)) return false;
+
+ actual_index.erase(actual_iter);
+ }
+
+ if (!actual_index.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Found unexpected node '",
+ SummarizeNodeDef(*actual_index.begin()->second),
+ "' not in expected graph:\n",
+ SummarizeGraphDef(expected));
+ }
+ return false;
+ }
+
+ return true;
+}
+
+namespace {
+
+string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
+ string ret;
+ for (int i = 0; i < f.size(); ++i) {
+ if (i > 0) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, f.Get(i));
+ }
+ return ret;
+}
+
+} // namespace
+
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected,
+ string* diff) {
+ if (actual.name() != expected.name()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Actual node name '", actual.name(),
+ "' is not expected '", expected.name(), "'");
+ }
+ return false;
+ }
+
+ if (actual.op() != expected.op()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
+ actual.op(), "' that is not expected '",
+ expected.op(), "'");
+ }
+ return false;
+ }
+
+ if (actual.device() != expected.device()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
+ actual.device(), "' that is not expected '",
+ expected.device(), "'");
+ }
+ return false;
+ }
+
+ if (actual.input_size() != expected.input_size()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
+ JoinStringField(actual.input()),
+ "' that don't match expected '",
+ JoinStringField(expected.input()), "'");
+ }
+ return false;
+ }
+
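+ // Data inputs must match pairwise and in order. Control inputs (those
+ // starting with "^") are expected to follow all data inputs and are
+ // compared as unordered sets below.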
+ int first_control_input = actual.input_size();
+ for (int i = 0; i < actual.input_size(); ++i) {
+ if (StringPiece(actual.input(i)).starts_with("^")) {
+ first_control_input = i;
+ break;
+ }
+ if (actual.input(i) != expected.input(i)) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
+ i, " '", actual.input(i),
+ "' that doesn't match expected '",
+ expected.input(i), "'");
+ }
+ return false;
+ }
+ }
+
+ std::unordered_set<string> actual_control;
+ std::unordered_set<string> expected_control;
+ for (int i = first_control_input; i < actual.input_size(); ++i) {
+ actual_control.insert(actual.input(i));
+ expected_control.insert(expected.input(i));
+ }
+ for (const auto& e : expected_control) {
+ if (actual_control.erase(e) == 0) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(),
+ "' missing expected control input '", e, "'");
+ }
+ return false;
+ }
+ }
+ if (!actual_control.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(),
+ "' has unexpected control input '",
+ *actual_control.begin(), "'");
+ }
+ return false;
+ }
+
+ std::unordered_set<string> actual_attr;
+ for (const auto& a : actual.attr()) {
+ actual_attr.insert(a.first);
+ }
+ for (const auto& e : expected.attr()) {
+ if (actual_attr.erase(e.first) == 0) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Node named '", actual.name(),
+ "' missing expected attr '", e.first,
+ "' with value: ", SummarizeAttrValue(e.second));
+ }
+ return false;
+ }
+ auto iter = actual.attr().find(e.first);
+ if (!AreAttrValuesEqual(e.second, iter->second)) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat(
+ "Node named '", actual.name(), "' has attr '", e.first,
+ "' with value: ", SummarizeAttrValue(iter->second),
+ " that does not match expected: ", SummarizeAttrValue(e.second));
+ }
+ return false;
+ }
+ }
+ if (!actual_attr.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat(
+ "Node named '", actual.name(), "' has unexpected attr '",
+ *actual_attr.begin(), "' with value: ",
+ SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
+ }
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/equal_graph_def.h b/tensorflow/core/graph/equal_graph_def.h
new file mode 100644
index 0000000000..7dd8aab340
--- /dev/null
+++ b/tensorflow/core/graph/equal_graph_def.h
@@ -0,0 +1,32 @@
+#ifndef TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_
+#define TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Determines if actual and expected are equal, ignoring ordering of
+// nodes, attrs, and control inputs. If the GraphDefs are different
+// and diff != nullptr, *diff is set to an explanation of the
+// difference. Note that we use node names to match up nodes between
+// the graphs, and so the naming of nodes must be consistent.
+bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
+ string* diff);
+
+// Determines if actual and expected are equal, ignoring ordering of
+// attrs and control inputs. If the NodeDefs are different and
+// diff != nullptr, *diff is set to an explanation of the difference.
+bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff);
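+
+// A minimal usage sketch; the GraphDefs "actual" and "expected" are
+// illustrative:
+//
+//   string diff;
+//   if (!EqualGraphDef(actual, expected, &diff)) {
+//     LOG(ERROR) << "GraphDefs differ: " << diff;
+//   }
+//
+// In tests, the TF_EXPECT_GRAPH_EQ macro below wraps the same check.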
+
+#define TF_EXPECT_GRAPH_EQ(expected, actual) \
+ do { \
+ string diff; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_EQUAL_GRAPH_DEF_H_
diff --git a/tensorflow/core/graph/equal_graph_def_test.cc b/tensorflow/core/graph/equal_graph_def_test.cc
new file mode 100644
index 0000000000..3a38b9e522
--- /dev/null
+++ b/tensorflow/core/graph/equal_graph_def_test.cc
@@ -0,0 +1,279 @@
+#include "tensorflow/core/graph/equal_graph_def.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("Input").Output("o: float");
+REGISTER_OP("Alternate").Output("o: float");
+REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float");
+
+Node* Input(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("Input", opts);
+}
+
+Node* Alternate(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("Alternate", opts);
+}
+
+Node* Cross(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("Cross", a, b, opts);
+}
+
+class EqualGraphDefTest : public ::testing::Test {
+ protected:
+ EqualGraphDefTest()
+ : e_(GraphDefBuilder::kFailImmediately),
+ a_(GraphDefBuilder::kFailImmediately) {
+ RequireDefaultOps();
+ }
+
+ bool Match() {
+ GraphDef expected;
+ e_.ToGraphDef(&expected);
+ GraphDef actual;
+ a_.ToGraphDef(&actual);
+ return EqualGraphDef(actual, expected, &diff_);
+ }
+
+ GraphDefBuilder e_;
+ GraphDefBuilder a_;
+ string diff_;
+};
+
+TEST_F(EqualGraphDefTest, Match) {
+ Input(e_.opts().WithName("A"));
+ Input(a_.opts().WithName("A"));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, NoMatch) {
+ Input(e_.opts().WithName("A"));
+ Input(a_.opts().WithName("B"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Did not find expected node 'A = Input[]()'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, MissingNode) {
+ Input(e_.opts().WithName("A"));
+ Input(e_.opts().WithName("B"));
+ Input(a_.opts().WithName("A"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Did not find expected node 'B = Input[]()'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, ExtraNode) {
+ Input(e_.opts().WithName("A"));
+ Input(a_.opts().WithName("A"));
+ Input(a_.opts().WithName("B"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ(
+ "Found unexpected node 'B = Input[]()' not in expected graph:\n"
+ "A = Input[]();\n",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, NodeOrder) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Cross(a, b, e_.opts().WithName("C"));
+
+ b = Input(a_.opts().WithName("B"));
+ a = Input(a_.opts().WithName("A"));
+ Cross(a, b, a_.opts().WithName("C"));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, NameMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ // Have to call EqualNodeDef() directly here, since EqualGraphDef()
+ // only calls EqualNodeDef() with nodes that have matching names.
+ EXPECT_FALSE(EqualNodeDef(a->def(), b->def(), &diff_));
+ EXPECT_EQ("Actual node name 'A' is not expected 'B'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, OpMismatch) {
+ Input(e_.opts().WithName("A"));
+ Alternate(a_.opts().WithName("A"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'A' has op 'Alternate' that is not expected 'Input'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, DeviceMatch) {
+ Input(e_.opts().WithName("A").WithDevice("/cpu:0"));
+ Input(a_.opts().WithName("A").WithDevice("/cpu:0"));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, DeviceMismatch) {
+ Input(e_.opts().WithName("A").WithDevice("/cpu:0"));
+ Input(a_.opts().WithName("A").WithDevice("/cpu:1"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'A' has device '/cpu:1' that is not expected '/cpu:0'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, InputMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Cross(a, a, e_.opts().WithName("C"));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ Cross(b, b, a_.opts().WithName("C"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, InputOrderMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Cross(a, b, e_.opts().WithName("C"));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ Cross(b, a, a_.opts().WithName("C"));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'C' has input 0 'B' that doesn't match expected 'A'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, ControlInputOrder) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Node* d = Input(e_.opts().WithName("D"));
+ Cross(a, a, e_.opts()
+ .WithName("E")
+ .WithControlInput(b)
+ .WithControlInput(c)
+ .WithControlInput(d));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ d = Input(a_.opts().WithName("D"));
+ Cross(a, a, a_.opts()
+ .WithName("E")
+ .WithControlInput(c)
+ .WithControlInput(d)
+ .WithControlInput(b));
+ EXPECT_TRUE(Match()) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, ControlInputMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Node* d = Input(e_.opts().WithName("D"));
+ Cross(a, a, e_.opts().WithName("E").WithControlInput(b).WithControlInput(c));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ d = Input(a_.opts().WithName("D"));
+ Cross(a, a, a_.opts().WithName("E").WithControlInput(b).WithControlInput(d));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ("Node named 'E' missing expected control input '^C'", diff_);
+}
+
+TEST_F(EqualGraphDefTest, ControlInputAdded) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Cross(a, a, e_.opts().WithName("D").WithControlInput(b));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ Cross(a, a, a_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ(
+ "Node named 'D' has inputs 'A, A, ^B, ^C' that don't match "
+ "expected 'A, A, ^B'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, ControlInputRemoved) {
+ Node* a = Input(e_.opts().WithName("A"));
+ Node* b = Input(e_.opts().WithName("B"));
+ Node* c = Input(e_.opts().WithName("C"));
+ Cross(a, a, e_.opts().WithName("D").WithControlInput(b).WithControlInput(c));
+
+ a = Input(a_.opts().WithName("A"));
+ b = Input(a_.opts().WithName("B"));
+ c = Input(a_.opts().WithName("C"));
+ Cross(a, a, a_.opts().WithName("D").WithControlInput(b));
+ EXPECT_FALSE(Match());
+ EXPECT_EQ(
+ "Node named 'D' has inputs 'A, A, ^B' that don't match "
+ "expected 'A, A, ^B, ^C'",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, Attr) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef same(a->def());
+ AddNodeAttr("foo", "bar", &same);
+ EXPECT_TRUE(EqualNodeDef(same, same, &diff_)) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, AttrAdded) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef actual(a->def());
+ AddNodeAttr("foo", "bar", &actual);
+ EXPECT_FALSE(EqualNodeDef(actual, a->def(), &diff_));
+ EXPECT_EQ("Node named 'A' has unexpected attr 'foo' with value: \"bar\"",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, AttrRemoved) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef expected(a->def());
+ AddNodeAttr("foo", "bar", &expected);
+ EXPECT_FALSE(EqualNodeDef(a->def(), expected, &diff_));
+ EXPECT_EQ("Node named 'A' missing expected attr 'foo' with value: \"bar\"",
+ diff_);
+}
+
+TEST_F(EqualGraphDefTest, AttrOrder) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef actual(a->def());
+ AddNodeAttr("foo", "bar", &actual);
+ AddNodeAttr("baz", 42, &actual);
+
+ NodeDef expected(a->def());
+ AddNodeAttr("baz", 42, &expected);
+ AddNodeAttr("foo", "bar", &expected);
+
+ EXPECT_TRUE(EqualNodeDef(actual, expected, &diff_)) << diff_;
+}
+
+TEST_F(EqualGraphDefTest, AttrMismatch) {
+ Node* a = Input(e_.opts().WithName("A"));
+ NodeDef actual(a->def());
+ AddNodeAttr("foo", "bar", &actual);
+ AddNodeAttr("baz", 5, &actual);
+
+ NodeDef expected(a->def());
+ AddNodeAttr("baz", 42, &expected);
+ AddNodeAttr("foo", "bar", &expected);
+
+ EXPECT_FALSE(EqualNodeDef(actual, expected, &diff_));
+ EXPECT_EQ(
+ "Node named 'A' has attr 'baz' with value: 5 that does not match "
+ "expected: 42",
+ diff_);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
new file mode 100644
index 0000000000..0c268a51a9
--- /dev/null
+++ b/tensorflow/core/graph/graph.cc
@@ -0,0 +1,319 @@
+#include "tensorflow/core/graph/graph.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Node
+
+string Node::DebugString() const {
+ if (this == nullptr) {
+ return "{nullptr}";
+ }
+ string ret = strings::StrCat("{name:'", name(), "' id:", id_);
+ if (IsSource()) {
+ strings::StrAppend(&ret, " source}");
+ } else if (IsSink()) {
+ strings::StrAppend(&ret, " sink}");
+ } else {
+ strings::StrAppend(&ret, " op device:");
+ strings::StrAppend(&ret, "{", assigned_device_name_, "}");
+ strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
+ }
+ return ret;
+}
+
+Node::Node()
+ : id_(-1), cost_id_(-1), props_(nullptr), assigned_device_name_() {}
+
+Node::~Node() {
+ if (props_) {
+ props_->Unref();
+ }
+}
+
+void Node::Initialize(int id, int cost_id, Properties* props) {
+ DCHECK_EQ(id_, -1);
+ DCHECK(in_edges_.empty());
+ DCHECK(out_edges_.empty());
+ id_ = id;
+ cost_id_ = cost_id;
+
+ // Unref the old, assign the new properties.
+ if (props_) {
+ props_->Unref();
+ }
+ props_ = props;
+}
+
+void Node::Clear() {
+ in_edges_.clear();
+ out_edges_.clear();
+ id_ = -1;
+ cost_id_ = -1;
+
+ if (props_) {
+ props_->Unref();
+ props_ = nullptr;
+ }
+
+ assigned_device_name_.clear();
+}
+
+gtl::iterator_range<NeighborIter> Node::out_nodes() const {
+ return gtl::make_range(NeighborIter(out_edges_.begin(), false),
+ NeighborIter(out_edges_.end(), false));
+}
+
+gtl::iterator_range<NeighborIter> Node::in_nodes() const {
+ return gtl::make_range(NeighborIter(in_edges_.begin(), true),
+ NeighborIter(in_edges_.end(), true));
+}
+
+// Node::Properties
+
+Node::Properties::Properties(const OpDef* op_def, const NodeDef& node_def,
+ const DataTypeSlice inputs,
+ const DataTypeSlice outputs)
+ : op_def_(op_def),
+ node_def_(node_def),
+ input_types_(inputs.begin(), inputs.end()),
+ output_types_(outputs.begin(), outputs.end()) {}
+
+Node::Properties::~Properties() {}
+
+// Graph
+
+Graph::Graph(const OpRegistryInterface* ops)
+ : ops_(ops), arena_(8 << 10 /* 8kB */) {
+ // Source and sink have no endpoints, just control edges.
+ NodeDef def;
+ def.set_name("_SOURCE");
+ def.set_op("NoOp");
+ Status status;
+ Node* source = AddNode(def, &status);
+ TF_CHECK_OK(status);
+ CHECK_EQ(source->id(), kSourceId);
+
+ def.set_name("_SINK");
+ Node* sink = AddNode(def, &status);
+ TF_CHECK_OK(status);
+ CHECK_EQ(sink->id(), kSinkId);
+
+ AddControlEdge(source, sink);
+}
+
+Graph::~Graph() {
+ // Manually call the destructors for all the Nodes we constructed using
+ // placement new.
+ for (Node* node : nodes_) {
+ if (node != nullptr) {
+ node->~Node();
+ }
+ }
+ for (Node* node : free_nodes_) {
+ node->~Node();
+ }
+ // Edges have no destructor, and we arena-allocated them, so no need to
+ // destroy them.
+}
+
+Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
+ const OpDef* op_def = ops_->LookUp(node_def.op(), status);
+ if (op_def == nullptr) return nullptr;
+
+ // TODO(vrv,josh11b): Find a location higher in the stack to add these defaults
+ // to the NodeDef.
+ NodeDef node_def_with_defaults(node_def);
+ AddDefaultsToNodeDef(*op_def, &node_def_with_defaults);
+
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ status->Update(
+ InOutTypesForNode(node_def_with_defaults, *op_def, &inputs, &outputs));
+ if (!status->ok()) {
+ *status = AttachDef(*status, node_def_with_defaults);
+ return nullptr;
+ }
+
+ Node* node = AllocateNode(
+ new Node::Properties(op_def, node_def_with_defaults, inputs, outputs),
+ nullptr);
+ return node;
+}
+
+Node* Graph::CopyNode(Node* node) {
+ DCHECK(!node->IsSource());
+ DCHECK(!node->IsSink());
+ Node::Properties* props = node->properties();
+ props->Ref();
+ Node* copy = AllocateNode(props, node);
+ copy->set_assigned_device_name(node->assigned_device_name());
+ return copy;
+}
+
+void Graph::RemoveNode(Node* node) {
+ DCHECK(IsValidNode(node)) << node->DebugString();
+ DCHECK(!node->IsSource());
+ DCHECK(!node->IsSink());
+
+ // Remove any edges involving this node.
+ while (!node->in_edges_.empty()) {
+ RemoveEdge(*node->in_edges_.begin());
+ }
+ while (!node->out_edges_.empty()) {
+ RemoveEdge(*node->out_edges_.begin());
+ }
+ ReleaseNode(node);
+}
+
+const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
+ DCHECK(IsValidNode(source)) << source->DebugString();
+ DCHECK(IsValidNode(dest)) << dest->DebugString();
+
+ // source/sink must only be linked via control slots, and
+ // control slots must only be linked to control slots.
+ if (source == source_node() || dest == sink_node() || x == kControlSlot ||
+ y == kControlSlot) {
+ DCHECK_EQ(x, kControlSlot) << source->DebugString();
+ DCHECK_EQ(y, kControlSlot) << dest->DebugString();
+ }
+
+ Edge* e = nullptr;
+ if (free_edges_.empty()) {
+ e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new
+ } else {
+ e = free_edges_.back();
+ free_edges_.pop_back();
+ }
+ e->id_ = edges_.size();
+ e->src_ = source;
+ e->dst_ = dest;
+ e->src_output_ = x;
+ e->dst_input_ = y;
+ CHECK(source->out_edges_.insert(e).second);
+ CHECK(dest->in_edges_.insert(e).second);
+ edges_.push_back(e);
+ edge_set_.insert(e);
+ return e;
+}
+
+void Graph::RemoveEdge(const Edge* e) {
+ DCHECK(IsValidNode(e->src_)) << e->src_->DebugString();
+ DCHECK(IsValidNode(e->dst_)) << e->dst_->DebugString();
+ CHECK_EQ(e->src_->out_edges_.erase(e), 1);
+ CHECK_EQ(e->dst_->in_edges_.erase(e), 1);
+ CHECK_EQ(e, edges_[e->id_]);
+
+ CHECK_EQ(edge_set_.erase(e), 1);
+ edges_[e->id_] = nullptr;
+
+ Edge* del = const_cast<Edge*>(e);
+ del->src_ = nullptr;
+ del->dst_ = nullptr;
+ del->id_ = -1;
+ del->src_output_ = kControlSlot - 1;
+ del->dst_input_ = kControlSlot - 1;
+ free_edges_.push_back(del);
+}
+
+namespace {
+
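+// Formats one input entry of a NodeDef: slot 0 uses the bare node name,
+// other data slots append ":<slot>", and control inputs are written "^name".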
+void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
+ if (src_slot == Graph::kControlSlot) {
+ dst->add_input(strings::StrCat("^", src_name));
+ } else if (src_slot == 0) {
+ dst->add_input(src_name.data(), src_name.size());
+ } else {
+ dst->add_input(strings::StrCat(src_name, ":", src_slot));
+ }
+}
+
+} // namespace
+
+void Graph::ToGraphDef(GraphDef* graph_def) const {
+ graph_def->Clear();
+ std::vector<const Edge*>
+ inputs; // Construct this outside the loop for speed.
+ for (const Node* node : nodes()) {
+ if (!node->IsOp()) continue;
+ NodeDef* node_def = graph_def->add_node();
+ *node_def = node->def();
+
+ // Use the node's assigned device, if any, instead of the device requested
+ // in the NodeDef.
+ if (!node->assigned_device_name().empty()) {
+ node_def->set_device(node->assigned_device_name());
+ }
+
+ // Get the inputs for this Node. We make sure control inputs are
+ // after data inputs, as required by GraphDef.
+ inputs.clear();
+ inputs.resize(node->num_inputs(), nullptr);
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ inputs.push_back(edge);
+ } else {
+ DCHECK(inputs[edge->dst_input()] == nullptr);
+ inputs[edge->dst_input()] = edge;
+ }
+ }
+ node_def->clear_input();
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ const Edge* edge = inputs[i];
+ if (edge == nullptr) {
+ node_def->add_input(node->def().input(i));
+ } else {
+ const Node* src = edge->src();
+ if (!src->IsOp()) continue;
+ AddInput(node_def, src->name(), edge->src_output());
+ }
+ }
+ }
+}
+
+string Graph::NewName(StringPiece prefix) {
+ return strings::StrCat(prefix, "/_", name_counter_++);
+}
+
+gtl::iterator_range<NodeIter> Graph::nodes() const {
+ // Note that NodeId 0 is always valid since we don't let the source
+ // node be removed from the graph.
+ return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
+}
+
+bool Graph::IsValidNode(Node* node) const {
+ if (node == nullptr) return false;
+ const int id = node->id();
+ if (id < 0 || static_cast<size_t>(id) >= nodes_.size()) return false;
+ return nodes_[id] == node;
+}
+
+Node* Graph::AllocateNode(Node::Properties* props, const Node* cost_node) {
+ Node* node = nullptr;
+ if (free_nodes_.empty()) {
+ node = new (arena_.Alloc(sizeof(Node))) Node; // placement new
+ } else {
+ node = free_nodes_.back();
+ free_nodes_.pop_back();
+ }
+ const int id = nodes_.size();
+ int cost_id = cost_node ? cost_node->cost_id() : id;
+ node->Initialize(id, cost_id, props);
+ nodes_.push_back(node);
+ return node;
+}
+
+void Graph::ReleaseNode(Node* node) {
+ DCHECK(IsValidNode(node)) << node->DebugString();
+ nodes_[node->id()] = nullptr;
+ free_nodes_.push_back(node);
+ node->Clear();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
new file mode 100644
index 0000000000..030e471bf4
--- /dev/null
+++ b/tensorflow/core/graph/graph.h
@@ -0,0 +1,440 @@
+// A Graph describes a set of computations that are to be
+// performed, as well as the dependencies between those
+// computations. The basic model is a DAG (directed acyclic graph) with
+// * internal nodes representing computational operations to be performed;
+// * edges representing dependencies, indicating that the target may only be
+// executed once the source has completed; and
+// * predefined "source" (start) and "sink" (finish) nodes -- the source
+// should be the only node that doesn't depend on anything, and the sink
+// should be the only node that nothing depends on.
+//
+// Note: Node ids are intended to be relatively dense in the
+// 0..max_id range, but there may be gaps since ids won't be reused.
+//
+// Note: Some dependencies between operations are due to one operation
+// consuming the output of another. In fact operations can produce
+// multiple outputs and consume multiple inputs, and some
+// optimizations will care about which specific outputs are connected
+// to which specific inputs. We therefore represent data dependency
+// between output O of layer A and input I of layer B using
+// "input index" and "output index" labels per edge.
+
+#ifndef TENSORFLOW_GRAPH_GRAPH_H_
+#define TENSORFLOW_GRAPH_GRAPH_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/edgeset.h"
+#include "tensorflow/core/lib/core/arena.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/iterator_range.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Edge;
+class EdgeSetTest;
+class Graph;
+class Node;
+
+class NeighborIter; // Declared below
+class NodeIter; // Declared below
+
+class Node {
+ public:
+ string DebugString() const;
+ int id() const { return id_; }
+ int cost_id() const { return cost_id_; }
+ const string& name() const { return props_->node_def_.name(); }
+ const string& type_string() const { return props_->node_def_.op(); }
+ const NodeDef& def() const { return props_->node_def_; }
+ const OpDef& op_def() const { return *props_->op_def_; }
+
+ // input and output types
+ int num_inputs() const { return props_->input_types_.size(); }
+ DataType input_type(int i) const { return props_->input_types_[i]; }
+ const DataTypeVector& input_types() const { return props_->input_types_; }
+
+ int num_outputs() const { return props_->output_types_.size(); }
+ DataType output_type(int o) const { return props_->output_types_[o]; }
+ const DataTypeVector& output_types() const { return props_->output_types_; }
+
+ // This gives the device the runtime has assigned this node to. If
+ // you want the device the user requested, use def().device() instead.
+ // TODO(josh11b): Validate that the assigned_device, if not empty:
+ // fully specifies a device, and satisfies def().device().
+ // TODO(josh11b): Move device_name outside of Node into a NodeId->DeviceName
+ // map.
+ string assigned_device_name() const { return assigned_device_name_; }
+ void set_assigned_device_name(const string& device_name) {
+ assigned_device_name_ = device_name;
+ }
+
+ // Get the neighboring nodes via edges either in or out of this node.
+ gtl::iterator_range<NeighborIter> in_nodes() const;
+ gtl::iterator_range<NeighborIter> out_nodes() const;
+ const EdgeSet& in_edges() const { return in_edges_; }
+ const EdgeSet& out_edges() const { return out_edges_; }
+
+ // Node type helpers.
+ bool IsSource() const { return id() == 0; }
+ bool IsSink() const { return id() == 1; }
+ // Anything other than the special Source & Sink nodes.
+ bool IsOp() const { return id() > 1; }
+
+ private:
+ friend class Graph;
+ Node();
+ ~Node();
+
+ class Properties : public core::RefCounted {
+ public:
+ Properties(const OpDef* op_def, const NodeDef& node_def,
+ const DataTypeSlice inputs, const DataTypeSlice outputs);
+
+ const OpDef* op_def_; // not owned
+ const NodeDef node_def_;
+ const DataTypeVector input_types_;
+ const DataTypeVector output_types_;
+
+ private:
+ // Destructor invoked when last reference goes away via Unref()
+ virtual ~Properties();
+ TF_DISALLOW_COPY_AND_ASSIGN(Properties);
+ };
+
+ Properties* properties() const { return props_; }
+
+ // Initialize() adopts a reference to props, and so is suitable if props was
+ // just allocated or you call props->Ref() to increment the reference
+ // count for a props being held by another Node.
+ void Initialize(int id, int cost_id, Properties* props);
+ // Releases memory from props_, in addition to restoring *this to its
+ // uninitialized state.
+ void Clear();
+
+ int id_; // -1 until Initialize() is called
+ int cost_id_; // -1 if there is no corresponding cost accounting node
+
+ EdgeSet in_edges_;
+ EdgeSet out_edges_;
+
+ Properties* props_;
+
+ // Name of device assigned to perform this computation.
+ string assigned_device_name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Node);
+};
+
+class Edge {
+ public:
+ Node* src() const { return src_; }
+ Node* dst() const { return dst_; }
+ int id() const { return id_; }
+
+ // Return the number of the source output that produces the data
+ // carried by this edge. The special value kControlSlot is used
+ // for control dependencies.
+ int src_output() const { return src_output_; }
+
+ // Return the number of the destination input that consumes the data
+ // carried by this edge. The special value kControlSlot is used
+ // for control dependencies.
+ int dst_input() const { return dst_input_; }
+
+ // Return true iff this is an edge that indicates a control-flow
+ // (as opposed to a data-flow) dependency.
+ bool IsControlEdge() const;
+
+ private:
+ Edge() {}
+
+ friend class EdgeSetTest;
+ friend class Graph;
+ Node* src_;
+ Node* dst_;
+ int id_;
+ int src_output_;
+ int dst_input_;
+};
+
+// Thread compatible but not thread safe.
+class Graph {
+ public:
+ // Constructs a graph with a single SOURCE (always id kSourceId) and a
+ // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
+ //
+ // The graph can hold ops found in registry.
+ explicit Graph(const OpRegistryInterface* registry);
+ ~Graph();
+
+ static const int kControlSlot = -1;
+
+ // Adds a new node to this graph, and returns it. Infers the Op and
+ // input/output types for the node. *this owns the returned instance.
+ // Returns nullptr and sets *status on error.
+ Node* AddNode(const NodeDef& node_def, Status* status);
+
+ // Copies *node, which may belong to another graph, to a new node,
+ // which is returned. Does not copy any edges. *this owns the
+ // returned instance.
+ Node* CopyNode(Node* node);
+
+ // Remove a node from this graph, including all edges from or to it.
+ // *node should not be accessed after calling this function.
+ // REQUIRES: node->IsOp()
+ void RemoveNode(Node* node);
+
+ // Add an edge that connects the xth output of "source" to the yth input
+ // of "dest".
+ const Edge* AddEdge(Node* source, int x, Node* dest, int y);
+
+ // Add a control-edge (no data flows along this edge) that
+ // connects "source" to "dest".
+ const Edge* AddControlEdge(Node* source, Node* dest) {
+ return AddEdge(source, kControlSlot, dest, kControlSlot);
+ }
+
+ // Removes edge from the graph.
+ // REQUIRES: The edge must exist.
+ void RemoveEdge(const Edge* edge);
+
+ // Returns one more than the maximum id assigned to any node.
+ int num_node_ids() const { return nodes_.size(); }
+
+ // Serialize to a GraphDef.
+ void ToGraphDef(GraphDef* graph_def) const;
+
+ // Generate new node name with the specified prefix that is unique
+ // across this graph.
+ string NewName(StringPiece prefix);
+
+ // Access to the list of all nodes. Example usage:
+ // for (Node* node : graph.nodes()) { ... }
+ gtl::iterator_range<NodeIter> nodes() const;
+
+ // Returns the node associated with an id, or nullptr if no node
+ // with that id (the node with that id was removed and the id has
+ // not yet been re-used). *this owns the returned instance.
+ // REQUIRES: 0 <= id < num_node_ids().
+ Node* FindNodeId(int id) const { return nodes_[id]; }
+
+ // Returns one more than the maximum id assigned to any edge.
+ int num_edge_ids() const { return edges_.size(); }
+
+  // Returns the Edge associated with an id, or nullptr if no edge
+  // with that id (the edge with that id was removed and the id has
+  // not yet been re-used). *this owns the returned instance.
+  // REQUIRES: 0 <= id < num_edge_ids().
+ const Edge* FindEdgeId(int id) const { return edges_[id]; }
+
+ // Access to the set of all edges. Example usage:
+ // for (const Edge* e : graph.edges()) { ... }
+ const EdgeSet& edges() const { return edge_set_; }
+
+ // The pre-defined nodes.
+ enum { kSourceId = 0, kSinkId = 1 };
+ Node* source_node() const { return FindNodeId(kSourceId); }
+ Node* sink_node() const { return FindNodeId(kSinkId); }
+
+ const OpRegistryInterface* op_registry() const { return ops_; }
+
+ // TODO(josh11b): uint64 hash() const;
+
+ private:
+ bool IsValidNode(Node* node) const;
+ // If cost_node is non-null, then cost accounting (in CostModel)
+ // will be associated with that node rather than the new one being
+ // created.
+ Node* AllocateNode(Node::Properties* props, const Node* cost_node);
+ void ReleaseNode(Node* node);
+
+ // Registry of all known ops. Not owned.
+ const OpRegistryInterface* const ops_;
+
+ // Allocator which will give us good locality.
+ core::Arena arena_;
+
+ // Map from node ids to allocated nodes. nodes_[id] may be nullptr if
+ // the node with that id was removed from the graph.
+ std::vector<Node*> nodes_;
+
+ // Map from edge ids to allocated edges. edges_[id] may be nullptr if
+ // the edge with that id was removed from the graph.
+ std::vector<Edge*> edges_;
+
+ // For ease of iteration, we currently just keep a set of all live
+ // edges. May want to optimize by removing this copy.
+ EdgeSet edge_set_;
+
+ // Allocated but free nodes and edges.
+ std::vector<Node*> free_nodes_;
+ std::vector<Edge*> free_edges_;
+
+ // For generating unique names.
+ int name_counter_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Graph);
+};
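+
+// Example usage (an illustrative sketch only; error handling is elided and
+// `some_node_def` stands for a NodeDef built elsewhere):
+//
+//   Graph graph(OpRegistry::Global());
+//   Status status;
+//   Node* n = graph.AddNode(some_node_def, &status);
+//   if (n == nullptr) { /* inspect status */ }
+//   graph.AddControlEdge(graph.source_node(), n);
+//   for (Node* node : graph.nodes()) { VLOG(1) << node->name(); }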
+
+// TODO(josh11b): We may want to support keeping an index on various
+// node/edge attributes in a graph, particularly node names.
+
+// Helper routines
+
+inline bool IsSwitch(const Node* node) {
+ return node->type_string() == "Switch" || node->type_string() == "RefSwitch";
+}
+
+inline bool IsMerge(const Node* node) { return node->type_string() == "Merge"; }
+
+inline bool IsEnter(const Node* node) {
+ return node->type_string() == "Enter" || node->type_string() == "RefEnter";
+}
+
+inline bool IsExit(const Node* node) { return node->type_string() == "Exit"; }
+
+inline bool IsNextIteration(const Node* node) {
+ return node->type_string() == "NextIteration";
+}
+
+inline bool IsLoopCond(const Node* node) {
+ return node->type_string() == "LoopCond";
+}
+
+inline bool IsControlTrigger(const Node* node) {
+ return node->type_string() == "ControlTrigger";
+}
+
+inline bool IsSend(const Node* node) {
+ return node->type_string() == "_Send" || node->type_string() == "_HostSend";
+}
+
+inline bool IsRecv(const Node* node) {
+ return node->type_string() == "_Recv" || node->type_string() == "_HostRecv";
+}
+
+// True for Nodes that mediate the transfer of values between processes.
+inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }
+
+inline bool IsConstant(const Node* node) {
+ return node->type_string() == "Const" || node->type_string() == "HostConst";
+}
+
+inline bool IsVariable(const Node* node) {
+ return node->type_string() == "Variable";
+}
+
+inline bool IsIdentity(const Node* node) {
+ return (node->type_string() == "Identity" ||
+ node->type_string() == "RefIdentity");
+}
+
+// Returns true iff 'n' is a control flow node.
+inline bool IsControlFlow(const Node* n) {
+ return IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n) ||
+ IsNextIteration(n);
+}
+
+inline bool IsHostMemoryPreserving(const Node* node) {
+ return IsIdentity(node) || IsControlFlow(node);
+}
+
+// Iterator for stepping through the nodes of a graph.
+class NodeIter {
+ public:
+ NodeIter(const Graph* graph, int id);
+ bool operator==(const NodeIter& rhs);
+ bool operator!=(const NodeIter& rhs);
+ void operator++();
+ Node* operator*();
+ Node* operator->();
+
+ private:
+  // Invariant:
+  //   id_ == graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr
+ const Graph* graph_;
+ int id_;
+};
+
+// Iterator for stepping through the neighbors of a node.
+class NeighborIter {
+ public:
+ NeighborIter(EdgeSet::const_iterator iter, bool incoming);
+ bool operator==(const NeighborIter& rhs);
+ bool operator!=(const NeighborIter& rhs);
+ void operator++();
+ Node* operator*();
+ Node* operator->();
+
+ private:
+ EdgeSet::const_iterator iter_;
+ bool incoming_;
+};
+
+// IMPLEMENTATION DETAILS, PLEASE IGNORE
+
+inline NodeIter::NodeIter(const Graph* graph, int id)
+ : graph_(graph), id_(id) {}
+
+inline bool NodeIter::operator==(const NodeIter& rhs) {
+ DCHECK(graph_ == rhs.graph_);
+ return id_ == rhs.id_;
+}
+
+inline bool NodeIter::operator!=(const NodeIter& rhs) {
+ return !(*this == rhs);
+}
+
+inline void NodeIter::operator++() {
+  while (true) {
+ DCHECK_LE(id_, graph_->num_node_ids());
+ ++id_;
+ if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
+ return;
+ }
+ }
+}
+
+inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }
+
+inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }
+
+inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
+ : iter_(iter), incoming_(incoming) {}
+
+inline bool NeighborIter::operator==(const NeighborIter& rhs) {
+ return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
+}
+
+inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
+ return !(*this == rhs);
+}
+
+inline void NeighborIter::operator++() { ++iter_; }
+
+inline Node* NeighborIter::operator*() {
+ const Edge* e = *iter_;
+ return incoming_ ? e->src() : e->dst();
+}
+
+inline Node* NeighborIter::operator->() {
+ const Edge* e = *iter_;
+ return incoming_ ? e->src() : e->dst();
+}
+
+inline bool Edge::IsControlEdge() const {
+ // Note that if either src_output_ or dst_input_ is kControlSlot,
+ // so is the other one (AddEdge checks this).
+ return src_output_ == Graph::kControlSlot;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_H_
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
new file mode 100644
index 0000000000..3928348f0a
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -0,0 +1,385 @@
+#include "tensorflow/core/graph/graph_constructor.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+inline bool IsMerge(const NodeDef& node_def) {
+ return node_def.op() == "Merge";
+}
+} // namespace
+
+namespace {
+
+class GraphConstructor {
+ public:
+ GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef,
+ Graph* g, Status* status)
+ : opts_(opts), gdef_(gdef), g_(g), status_(status) {
+ BuildNodeIndex();
+ InitFromEdges();
+ Convert();
+ }
+
+ private:
+ void SetError(const string& error);
+ void SetNodeError(const NodeDef& node_def, const StringPiece& message) {
+ SetError(strings::StrCat("Node '", node_def.name(), "': ", message));
+ }
+ void BuildNodeIndex();
+ void InitFromEdges();
+ Node* MakeNode(const NodeDef& node_def);
+ void Convert();
+ // Calls SetError() and returns false if the type of the output of
+ // the source of the edge can't be consumed by destination of the edge.
+ // REQUIRES: edge must be a data edge, not a control edge.
+ bool TypeValidateEdge(const Edge* edge);
+
+ // From constructor
+ const GraphConstructorOptions opts_;
+ const GraphDef* gdef_;
+ Graph* g_;
+ Status* status_;
+
+ // Mapping from node name to the index within gdef_
+ struct NodeInfo {
+ explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
+ // std::unordered_map<> requires that we have a default constructor.
+ NodeInfo() : NodeInfo(-1) {}
+ int gdef_index;
+ Node* node; // nullptr until the NodeDef is converted to a Node.
+ };
+ // TODO(vrv): Profile this data structure to see if we should use an
+ // alternative implementation of std::unordered_map.
+ std::unordered_map<StringPiece, NodeInfo, StringPiece::Hasher> name_index_;
+
+ // Index of NodeDefs in gdef_ with all inputs already converted.
+ std::vector<int> ready_;
+
+ // Mapping between index within gdef_ and the number of inputs that
+ // still need to be converted.
+ std::vector<int> pending_count_;
+
+ // Mapping between index within gdef_ and the index within gdef_ of
+ // all nodes it outputs to.
+ std::vector<gtl::InlinedVector<int, 4>> outputs_;
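+  // For example, for a GraphDef with nodes A, B (input: A), C (inputs: A, B)
+  // at indices 0, 1, 2, InitFromEdges() produces (sketch):
+  //   ready_         = {0}          // A has no inputs
+  //   pending_count_ = {0, 1, 2}
+  //   outputs_       = {{1, 2}, {2}, {}}
+  // Convert() then repeatedly pops an index off ready_, converts that
+  // NodeDef, and decrements pending_count_ for each of its outputs,
+  // enqueueing any that drop to zero.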
+
+ // Used in the conversion from gdef_ to g_ to represent the ith input
+ // of a node.
+ struct InputInfo {
+ explicit InputInfo(StringPiece node_name, Node* n, int i)
+ : name(node_name), node(n), index(i) {}
+ StringPiece name;
+ Node* node;
+ int index;
+ };
+
+ // Used in the conversion from gdef_ to g_ to represent an edge from
+ // the node named 'name' to node 'n'.
+ struct EdgeInfo {
+ explicit EdgeInfo(StringPiece name, int i1, Node* n, int i2)
+ : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
+ StringPiece src_name;
+ int src_index;
+ Node* dst_node;
+ int dst_index;
+ };
+};
+
+void GraphConstructor::SetError(const string& error) {
+ status_->Update(errors::InvalidArgument(error));
+}
+
+void GraphConstructor::BuildNodeIndex() {
+ // Initialized outside the loop for efficiency
+ const char* pattern;
+ if (opts_.allow_internal_ops) {
+ pattern = "[A-Za-z0-9._][A-Za-z0-9_.\\-/]*";
+ } else {
+ pattern = "[A-Za-z0-9.][A-Za-z0-9_.\\-/]*";
+ }
+ RE2 node_name_re(pattern);
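+  // For example, "a:b" is rejected by both patterns, while a name with a
+  // leading '_' is accepted only when allow_internal_ops is true.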
+
+ // Validate the node names and add them to name_index_.
+ for (int n = 0; n < gdef_->node_size(); ++n) {
+ const NodeDef& node_def(gdef_->node(n));
+ if (!RE2::FullMatch(node_def.name(), node_name_re)) {
+ SetNodeError(node_def, "Node name contains invalid characters");
+ return;
+ }
+ if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()),
+ NodeInfo(n)))
+ .second) {
+ SetNodeError(node_def, "Node name is not unique");
+ return;
+ }
+ // Validate the operation's type.
+ if (node_def.op().empty()) {
+ SetNodeError(node_def, "Does not specify a type");
+ return;
+ }
+ if (opts_.expect_device_spec && node_def.device().empty()) {
+ SetNodeError(node_def, strings::StrCat("Missing device specification."));
+ return;
+ }
+ }
+}
+
+void GraphConstructor::InitFromEdges() {
+ const int num_nodes = gdef_->node_size();
+ ready_.reserve(num_nodes);
+ pending_count_.reserve(num_nodes);
+ outputs_.resize(num_nodes);
+
+ // Parse the inputs for each node.
+ for (int n = 0; n < num_nodes; ++n) {
+ const NodeDef& node_def(gdef_->node(n));
+ if (IsMerge(node_def)) {
+      // For Merge nodes, we only wait for one non-control input (plus all of
+      // the control inputs).
+      int32 num_control_edges = 0;
+      for (int i = 0; i < node_def.input_size(); ++i) {
+        StringPiece input_name(node_def.input(i));
+        if (input_name.starts_with("^")) {
+ num_control_edges++;
+ }
+ }
+ pending_count_.push_back(num_control_edges + 1);
+ } else {
+ pending_count_.push_back(node_def.input_size());
+ }
+ if (node_def.input_size() == 0) {
+ ready_.push_back(n);
+ continue;
+ }
+ for (int i = 0; i < node_def.input_size(); ++i) {
+ StringPiece input_name = node_def.input(i);
+ if (input_name.starts_with("^")) {
+ // Control dependence
+ input_name.remove_prefix(1);
+ }
+ TensorId id(ParseTensorName(input_name));
+ auto iter = name_index_.find(id.first);
+ if (iter == name_index_.end()) {
+ SetNodeError(node_def,
+ strings::StrCat("Unknown input node ", node_def.input(i)));
+ return;
+ }
+ outputs_[iter->second.gdef_index].push_back(n);
+ }
+ }
+}
+
+Node* GraphConstructor::MakeNode(const NodeDef& node_def) {
+ // Add the node to the graph.
+ Node* node = g_->AddNode(node_def, status_);
+ if (node == nullptr) return nullptr;
+ if (opts_.expect_device_spec) {
+ node->set_assigned_device_name(node_def.device());
+ }
+ name_index_[node_def.name()].node = node;
+ return node;
+}
+
+// Return the number of nodes in "g"
+static int CountNodes(Graph* g) {
+ int nodes = 0;
+ for (Node* node : g->nodes()) {
+ VLOG(1) << node; // Dummy use to avoid compiler warning
+ nodes++;
+ }
+ return nodes;
+}
+
+void GraphConstructor::Convert() {
+ std::vector<InputInfo> inputs;
+ std::vector<EdgeInfo> back_edges;
+ int processed = 0;
+ // Process the NodeDefs in topological order.
+ while (!ready_.empty()) {
+ int o = ready_.back();
+ ready_.pop_back();
+ ++processed;
+ const NodeDef& node_def(gdef_->node(o));
+ inputs.clear();
+ bool in_control_dependence = false;
+ bool has_data_back_edge = false;
+ for (int i = 0; i < node_def.input_size(); ++i) {
+ StringPiece input_name(node_def.input(i));
+      if (input_name.starts_with("^")) {
+ // A control dependence
+ in_control_dependence = true;
+ input_name.remove_prefix(1);
+ } else {
+ if (in_control_dependence) {
+ SetNodeError(node_def, strings::StrCat(
+ "Control dependencies must come after ",
+ "regular dependencies: input ", input_name,
+ " of source node ", node_def.name()));
+ return;
+ }
+ }
+ TensorId id(ParseTensorName(input_name));
+ auto iter = name_index_.find(id.first);
+ DCHECK(iter != name_index_.end());
+ Node* src_node = iter->second.node;
+ if (in_control_dependence) {
+ inputs.push_back(InputInfo(id.first, src_node, -1));
+ } else {
+ if (src_node == nullptr) {
+ has_data_back_edge = true;
+ inputs.push_back(InputInfo(id.first, src_node, id.second));
+ } else {
+ if (id.second >= src_node->num_outputs()) {
+ SetNodeError(
+ node_def,
+ strings::StrCat("Connecting to invalid output ", id.second,
+ " of source node ", id.first, " which has ",
+ src_node->num_outputs(), " outputs"));
+ return;
+ }
+ inputs.push_back(InputInfo(id.first, src_node, id.second));
+ }
+ }
+ }
+    if (has_data_back_edge && !IsMerge(node_def)) {
+      SetError(strings::StrCat(
+          node_def.name(),
+          " had a data back edge, but only Merge nodes can have back edges."));
+ return;
+ }
+
+ Node* node = MakeNode(node_def);
+ if (node == nullptr) return;
+
+ // Add edges from inputs to *node to the graph.
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ if (inputs[i].node == nullptr) {
+ // Record this back edge, which will be added after all nodes
+ // are created.
+ back_edges.push_back(
+ EdgeInfo(inputs[i].name, inputs[i].index, node, i));
+ } else if (inputs[i].index == -1) {
+ g_->AddControlEdge(inputs[i].node, node);
+ } else {
+ const Edge* edge =
+ g_->AddEdge(inputs[i].node, inputs[i].index, node, i);
+ if (!TypeValidateEdge(edge)) return;
+ }
+ }
+
+ // Update pending_count_ for outputs.
+ for (size_t i = 0; i < outputs_[o].size(); ++i) {
+ const int output = outputs_[o][i];
+ pending_count_[output]--;
+ if (pending_count_[output] == 0) {
+ ready_.push_back(output);
+ }
+ }
+ }
+
+ // Add the back edges after all nodes are created.
+ for (auto e : back_edges) {
+ Node* src_node = name_index_[e.src_name].node;
+ if (e.src_index == -1) {
+ g_->AddControlEdge(src_node, e.dst_node);
+ } else {
+ const Edge* edge =
+ g_->AddEdge(src_node, e.src_index, e.dst_node, e.dst_index);
+ if (!TypeValidateEdge(edge)) return;
+ }
+
+ VLOG(2) << "Add back edge: " << src_node->name() << " -> "
+ << e.dst_node->name();
+ }
+
+ if (processed < gdef_->node_size()) {
+ SetError(
+ strings::StrCat(gdef_->node_size() - processed, " nodes in a cycle"));
+ return;
+ }
+
+ if (status_->ok()) {
+ FixupSourceAndSinkEdges(g_);
+
+ if (opts_.optimizer_do_cse) {
+ if (!back_edges.empty()) {
+ LOG(WARNING) << "Not doing CSE. We need to figure out how to handle "
+ << "loops in the CSE phase.";
+ } else {
+ VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes";
+ OptimizeCSE(g_, opts_.cse_consider_function);
+ VLOG(1) << "Finished CSE: graph of " << CountNodes(g_) << " nodes";
+ }
+ }
+ }
+}
+
+bool GraphConstructor::TypeValidateEdge(const Edge* edge) {
+ DataType src_out = edge->src()->output_type(edge->src_output());
+ DataType dst_in = edge->dst()->input_type(edge->dst_input());
+ if (!TypesCompatible(dst_in, src_out)) {
+ SetError(strings::StrCat(
+ "Input ", edge->dst_input(), " of node ", edge->dst()->name(),
+ " was passed ", DataTypeString(src_out), " from ", edge->src()->name(),
+ ":", edge->src_output(), " incompatible with expected ",
+ DataTypeString(dst_in), "."));
+ return false;
+ }
+ return true;
+}
+
+} // namespace
+
+// ----------------------------------------------------------------------------
+// ConvertGraphDefToGraph
+// ----------------------------------------------------------------------------
+
+Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
+ const GraphDef& gdef, Graph* g) {
+ Status status;
+ GraphConstructor constructor(opts, &gdef, g, &status);
+ return status;
+}
+
+// ----------------------------------------------------------------------------
+// CopyGraph
+// ----------------------------------------------------------------------------
+void CopyGraph(const Graph& src, Graph* dest) {
+ for (Node* n : dest->nodes()) {
+ CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
+ }
+
+ // Copy the nodes
+ std::unordered_map<Node*, Node*>
+ node_map; // "Node in src" -> "Node in *dest"
+ node_map[src.source_node()] = dest->source_node();
+ node_map[src.sink_node()] = dest->sink_node();
+ for (Node* n : src.nodes()) {
+ if (n->IsSource() || n->IsSink()) continue;
+ CHECK(n->IsOp());
+ node_map[n] = dest->CopyNode(n);
+ }
+
+ // Copy the edges
+ for (const Edge* e : src.edges()) {
+ Node* src_copy = node_map[e->src()];
+ Node* dst_copy = node_map[e->dst()];
+ dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
new file mode 100644
index 0000000000..cd1615ef6b
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+#define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Construct a graph *g out of a GraphDef gdef. Returns non-OK on
+// error, in which case *g is left in an incomplete state.
+struct GraphConstructorOptions {
+ // If true, allows internal ops in the GraphDef.
+ bool allow_internal_ops = false;
+
+ // If true, the graph def is expected to have fully specified
+ // devices for all nodes. A node in the resulting graph "g" has the
+ // device name set accordingly.
+ //
+ // TODO(zhifengc): if possible, consider removing this option.
+ bool expect_device_spec = false;
+
+ // If true, perform common subexpression elimination on the graph.
+ // TODO(jeff): Turn this default to true?
+ bool optimizer_do_cse = false;
+
+ // If "optimizer_do_cse" is true and "cse_consider_function" is
+ // not nullptr, then only consider nodes for CSE for which
+ // "cse_consider_function(node)" returns true.
+ std::function<bool(const Node*)> cse_consider_function = nullptr;
+};
+extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
+ const GraphDef& gdef, Graph* g);
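+
+// Example (an illustrative sketch; assumes a GraphDef `gdef` parsed or built
+// elsewhere):
+//
+//   GraphConstructorOptions opts;
+//   Graph graph(OpRegistry::Global());
+//   Status s = ConvertGraphDefToGraph(opts, gdef, &graph);
+//   if (!s.ok()) { /* graph is left in an incomplete state */ }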
+
+// Make a copy of "src" into "*dest".
+//
+// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
+// other than the implicit Source/Sink nodes.
+extern void CopyGraph(const Graph& src, Graph* dest);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
new file mode 100644
index 0000000000..61f4427297
--- /dev/null
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -0,0 +1,190 @@
+#include "tensorflow/core/graph/graph_constructor.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/public/status.h"
+#include <gtest/gtest.h>
+
+// TODO(josh11b): Test InitCostModel().
+// TODO(josh11b): Test setting the "device" field of a NodeDef.
+// TODO(josh11b): Test that feeding won't prune targets.
+
+namespace tensorflow {
+namespace {
+
+class GraphConstructorTest : public ::testing::Test {
+ protected:
+ GraphConstructorTest() : g_(new Graph(OpRegistry::Global())) {
+ RequireDefaultOps();
+ }
+ ~GraphConstructorTest() override {}
+
+ void Convert(const string& gdef_ascii) {
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
+ }
+
+ void ExpectError(const string& gdef_ascii, const string& expected_error_re) {
+ Convert(gdef_ascii);
+ GraphConstructorOptions opts;
+ Status status = ConvertGraphDefToGraph(opts, gdef_, g_.get());
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(RE2::PartialMatch(status.error_message(), expected_error_re))
+ << status;
+ }
+
+ void ExpectOK(const string& gdef_ascii) {
+ Convert(gdef_ascii);
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
+ }
+
+ Node* FindNode(const string& name) {
+ for (Node* n : g_->nodes()) {
+ if (n->name() == name) return n;
+ }
+ return nullptr;
+ }
+
+ bool HasNode(const string& name) { return FindNode(name) != nullptr; }
+
+ void ExpectNodes(const string& nodes) {
+ int count = 0;
+ std::vector<string> actual_nodes;
+ for (Node* n : g_->nodes()) {
+ if (n->IsOp()) {
+ count++;
+ actual_nodes.push_back(n->name());
+ }
+ }
+ std::sort(actual_nodes.begin(), actual_nodes.end());
+
+ LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " ");
+
+ std::vector<string> expected_nodes = str_util::Split(nodes, ',');
+ std::sort(expected_nodes.begin(), expected_nodes.end());
+ for (const string& s : expected_nodes) {
+ Node* n = FindNode(s);
+ EXPECT_TRUE(n != nullptr) << s;
+ }
+
+ EXPECT_TRUE(actual_nodes.size() == expected_nodes.size())
+ << "\nActual: " << str_util::Join(actual_nodes, ",")
+ << "\nExpected: " << str_util::Join(expected_nodes, ",");
+ }
+
+  bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
+    for (const Edge* e : g_->edges()) {
+      if (e->src()->name() == src && e->src_output() == src_out &&
+          e->dst()->name() == dst && e->dst_input() == dst_in)
+        return true;
+ }
+ return false;
+ }
+ bool HasControlEdge(const string& src, const string& dst) {
+ return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
+ }
+
+ private:
+ GraphDef gdef_;
+ std::unique_ptr<Graph> g_;
+};
+
+REGISTER_OP("ABC");
+REGISTER_OP("TestParams").Output("o: float");
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+REGISTER_OP("TestInt").Input("a: int32");
+
+TEST_F(GraphConstructorTest, InvalidNodeName) {
+ ExpectError("node { name: 'a:b' op: 'ABC' }",
+ "Node 'a:b': Node name contains invalid characters");
+ ExpectError("node { name: '_abc' op: 'ABC' }",
+ // Can't start with '_'
+ "Node '_abc': Node name contains invalid characters");
+ ExpectOK("node { name: 'a-bc_' op: 'ABC' }");
+}
+
+TEST_F(GraphConstructorTest, InvalidSourceNodeName) {
+ ExpectError(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: 'W999' input: 'input' }",
+
+ "Unknown input node.*W999");
+}
+
+TEST_F(GraphConstructorTest, InvalidSourceNodeIndex) {
+ ExpectError(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1:1', 'input:1' ] }",
+
+ "Connecting to invalid output 1 of source node W1");
+}
+
+TEST_F(GraphConstructorTest, GraphWithCycle) {
+ ExpectError(
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }",
+
+ "cycle");
+}
+
+TEST_F(GraphConstructorTest, TypeMismatch) {
+ ExpectError(
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 'int' op: 'TestInt' input: [ 'input' ] }",
+
+ "Input 0 of node int was passed float from input:0 incompatible with "
+ "expected int32.");
+}
+
+TEST_F(GraphConstructorTest, EmptyGraph) { ExpectOK(""); }
+
+TEST_F(GraphConstructorTest, SimpleModel) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }");
+ EXPECT_TRUE(HasNode("W1"));
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasEdge("W1", 0, "t1", 0));
+ EXPECT_TRUE(HasEdge("input", 1, "t1", 0));
+}
+
+TEST_F(GraphConstructorTest, SimpleModelWithControlEdges) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W1', 'input:1', '^t1' ] }");
+ EXPECT_TRUE(HasNode("W1"));
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("t1"));
+ EXPECT_TRUE(HasNode("t2"));
+ EXPECT_TRUE(HasEdge("W1", 0, "t1", 0));
+ EXPECT_TRUE(HasEdge("input", 1, "t1", 0));
+ EXPECT_TRUE(HasEdge("W1", 0, "t2", 0));
+ EXPECT_TRUE(HasEdge("input", 1, "t2", 0));
+ EXPECT_TRUE(HasControlEdge("W1", "input"));
+ EXPECT_TRUE(HasControlEdge("t1", "t2"));
+}
+
+TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) {
+ ExpectError(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' input: [ '^W1' ] }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W1', '^t1', 'input:1' ] }",
+ "Node 't2': Control dependencies must come after regular dependencies");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
new file mode 100644
index 0000000000..979604f948
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -0,0 +1,121 @@
+#include "tensorflow/core/graph/graph_def_builder.h"
+
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+GraphDefBuilder::Options::Options(Graph* graph, Status* status)
+ : graph_(graph), status_(status) {}
+GraphDefBuilder::Options::~Options() {}
+
+GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
+ StringPiece name) const {
+ return Options(*this).WithNameImpl(name);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
+ StringPiece device) const {
+ return Options(*this).WithDeviceImpl(device);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
+ Node* control_input) const {
+ return Options(*this).WithControlInputImpl(control_input);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
+ gtl::ArraySlice<Node*> control_inputs) const {
+ return Options(*this).WithControlInputsImpl(control_inputs);
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
+ StringPiece name) {
+ name_ = name.ToString();
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
+ StringPiece device) {
+ device_ = device.ToString();
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
+ Node* control_input) {
+ control_inputs_.push_back(control_input);
+ return *this;
+}
+GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
+ gtl::ArraySlice<Node*> control_inputs) {
+ control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
+ control_inputs.end());
+ return *this;
+}
+
+Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
+ if (status_.ok()) {
+ graph_.ToGraphDef(graph_def);
+ }
+ return status_;
+}
+
+Status GraphDefBuilder::ToGraph(Graph* graph) const {
+ if (status_.ok()) {
+ GraphDef graph_def;
+ graph_.ToGraphDef(&graph_def);
+ GraphConstructorOptions opts;
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, graph));
+ }
+ return status_;
+}
+
+string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
+ if (name_.empty()) return graph_->NewName(op);
+ return name_;
+}
+
+Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
+ builder->ControlInputs(control_inputs_);
+ if (!device_.empty()) builder->Device(device_);
+ for (const auto& attr : attrs_) {
+ builder->Attr(attr.first, attr.second);
+ }
+
+ Node* returned_node;
+ UpdateStatus(builder->Finalize(graph_, &returned_node));
+ return returned_node;
+}
+
+void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
+ if (status_ == nullptr) {
+ TF_CHECK_OK(status);
+ } else {
+ status_->Update(status);
+ }
+}
+
+namespace ops {
+
+Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* UnaryOp(const string& op_name, NodeOut input,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ node_builder.Input(input);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
+ opts.op_registry());
+ node_builder.Input(a).Input(b);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+} // end namespace ops
+} // end namespace tensorflow
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
new file mode 100644
index 0000000000..bb72f9eea6
--- /dev/null
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -0,0 +1,181 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+
+#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// Given a function like:
+// namespace ops {
+// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
+// if (opts.HaveError()) return nullptr;
+// static const string kOpName = "Identity";
+// NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
+// opts.op_registry());
+// node_builder.Input(input);
+// return opts.FinalizeBuilder(&node_builder);
+// }
+// } // namespace ops
+//
+// // Or, alternatively:
+// namespace ops {
+// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
+// static const string kOpName = "Identity";
+// return UnaryOp(kOpName, input, opts);
+// }
+// } // namespace ops
+//
+// You call it like:
+// GraphDefBuilder b;
+// using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+//   Node* na = Const(7, b.opts());
+//   // Note: WithName() returns a copy, opts is unchanged.
+//   Node* nb = Const(5, b.opts().WithName("control-input"));
+//   Node* nc = Identity(na, b.opts().WithControlInput(nb));
+// GraphDef graph_def;
+// Status status = b.ToGraphDef(&graph_def);
+// if (!status.ok()) { /* Handle error */ }
+//
+// In tests you can skip the status handling via:
+// GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+// ...
+// b.ToGraphDef(&graph_def);
+
+class GraphDefBuilder {
+ public:
+ // Options for adding a Node to a Graph.
+ class Options {
+ public:
+ // Sets the Graph (that Nodes will be added to) and the status. The
+ // status may be set to nullptr, in which case errors cause CHECK
+ // failures. The graph and status must outlive *this.
+ Options(Graph* graph, Status* status);
+ ~Options();
+
+ // Methods for setting options. These are const methods: they
+ // return a copy of *this with the option set.
+ Options WithName(StringPiece name) const;
+ Options WithDevice(StringPiece device) const;
+ Options WithControlInput(Node* control_input) const;
+ Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
+
+ // Override the default value for an optional attr.
+ template <class T>
+ Options WithAttr(StringPiece attr_name, T&& value) const {
+ return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
+ }
+ // Note: overload needed to allow {...} expressions for value.
+ template <class T>
+ Options WithAttr(StringPiece attr_name,
+ std::initializer_list<T> value) const {
+ return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
+ }
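+    // For example (illustrative only; the attr names depend on the op):
+    //   b.opts().WithAttr("dtype", DT_FLOAT)
+    //   b.opts().WithAttr("strides", {1, 1, 1, 1})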
+
+ // Methods for using options from a function that creates a Node.
+
+ // Returns true if the status associated with *this has an error.
+ // Use this to skip processing that may depend on prior results.
+ bool HaveError() const { return status_ != nullptr && !status_->ok(); }
+
+ // Given the Op type name, return a name for a node of that type.
+ // Uses the value set in WithName() if that has been called. Otherwise,
+ // returns a name built out of the Op type name.
+ string GetNameForOp(StringPiece op) const;
+
+ // Sets the device, adds control inputs, adds attrs, and calls Finalize().
+ // If Finalize returns an error, it is saved and this function returns
+ // nullptr.
+ Node* FinalizeBuilder(NodeBuilder* builder) const;
+
+ // Updates the associated status, if any, or calls TF_CHECK_OK if none.
+ void UpdateStatus(const Status& status) const;
+
+ // Accessor
+ const OpRegistryInterface* op_registry() const {
+ return graph_->op_registry();
+ }
+
+ private:
+ Options WithNameImpl(StringPiece name);
+ Options WithDeviceImpl(StringPiece device);
+ Options WithControlInputImpl(Node* control_input);
+ Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
+ template <class T>
+ Options WithAttrImpl(StringPiece name, T&& value) {
+ attrs_.emplace_back(name.ToString(), AttrValue());
+ SetAttrValue(std::forward<T>(value), &attrs_.back().second);
+ return *this;
+ }
+
+ Graph* const graph_;
+ Status* const status_;
+ string name_;
+ string device_;
+ std::vector<Node*> control_inputs_;
+ std::vector<std::pair<string, AttrValue>> attrs_;
+ };
+
+ // Start building a new graph.
+ explicit GraphDefBuilder(
+ const OpRegistryInterface* op_registry = OpRegistry::Global())
+ : graph_(op_registry), opts_(&graph_, &status_) {}
+
+ // For use in tests, where you want to fail immediately on error instead
+ // of checking the status at the end.
+ enum TestFailImmediatelyType { kFailImmediately };
+ explicit GraphDefBuilder(
+ TestFailImmediatelyType,
+ const OpRegistryInterface* op_registry = OpRegistry::Global())
+ : graph_(op_registry), opts_(&graph_, nullptr) {}
+
+ // Gets the Options with the associated Graph and Status.
+ const Options& opts() const { return opts_; }
+
+ // Once all the nodes have been added, call this to get whether it was
+ // successful, and if so fill *graph_def.
+ Status ToGraphDef(GraphDef* graph_def) const;
+
+ // Like ToGraphDef(), but converts to a Graph (using the default
+ // GraphConstructorOptions).
+ // TODO(josh11b): Make this faster; right now it converts
+ // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds
+ // edges from the source and to the sink node, resolves back edges
+ // by name), and makes sure the resulting graph is valid.
+ Status ToGraph(Graph* graph) const;
+
+ private:
+ Graph graph_;
+ Status status_;
+ Options opts_;
+};
+
+namespace ops {
+
+// A NodeOut may either be a regular input or back input. Regular
+// inputs are specified via either a Node* or a Node* and an output
+// index. Back inputs are specified by a node name, output index, and
+// output type.
+typedef NodeBuilder::NodeOut NodeOut;
+
+// For adding an Op with no inputs to a GraphDefBuilder.
+Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
+
+// For adding an Op with one input to a GraphDefBuilder.
+Node* UnaryOp(const string& op_name, NodeOut input,
+ const GraphDefBuilder::Options& opts);
+
+// For adding an Op with two inputs to a GraphDefBuilder.
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
+ const GraphDefBuilder::Options& opts);
+
+} // namespace ops
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
new file mode 100644
index 0000000000..1571790e59
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -0,0 +1,1050 @@
+#include "tensorflow/core/graph/graph_partition.h"
+
+#include <deque>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
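+// Key used to de-duplicate the Recv nodes created during partitioning: edges
+// from the same (src node, src output slot) into the same destination
+// subgraph, with the same placement of the Recv output, can share one Recv.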
+struct DupRecvKey {
+ int src_node_id; // Edge's src node id
+ int src_output_slot; // Edge's src node output slot
+ GraphDef* dst_graph; // Edge's dst node is in this subgraph
+ bool recv_output_on_host; // The output of recv is on host
+};
+
+struct DupRecvKeyHash {
+ size_t operator()(const DupRecvKey& k) const {
+ size_t h = Hash64(reinterpret_cast<const char*>(&k.src_node_id),
+ sizeof(k.src_node_id), k.src_output_slot);
+ h = Hash64(reinterpret_cast<const char*>(&k.dst_graph), sizeof(k.dst_graph),
+ h);
+ h = Hash64(reinterpret_cast<const char*>(&k.recv_output_on_host),
+ sizeof(k.recv_output_on_host), h);
+ return h;
+ }
+};
+
+struct DupRecvKeyEq {
+ bool operator()(const DupRecvKey& x, const DupRecvKey& y) const {
+ return (x.src_node_id == y.src_node_id) &&
+ (x.src_output_slot == y.src_output_slot) &&
+ (x.dst_graph == y.dst_graph) &&
+ (x.recv_output_on_host == y.recv_output_on_host);
+ }
+};
+
+// struct used to store the recvs, so that start times can be properly updated
+struct RecvInfo {
+ NodeDef* recv;
+ NodeDef* real_recv;
+ int64 start_time;
+};
+
+typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
+ DupRecvTable;
+
+// Control flow info for a graph node.
+struct ControlFlowInfo {
+ const Node* frame = nullptr; // frame of a node
+ const Node* parent_frame = nullptr; // parent frame of a node
+ string frame_name; // frame name of a node
+ int iter_level = -1; // level of a node
+};
+
+struct PairIntHash {
+ public:
+ std::size_t operator()(const std::pair<int, int>& x) const {
+ return std::hash<int>()(x.first) ^ std::hash<int>()(x.second);
+ }
+};
+// A map used to store memory types for the inputs/outputs of every node.
+// The key is a pair of ints consisting of a node id and input/output index.
+typedef std::unordered_map<std::pair<int, int>, MemoryType, PairIntHash>
+ MemoryTypeMap;
+
+// We collect the following information about the graph before performing
+// graph partitioning.
+struct GraphInfo {
+ std::vector<DeviceType> device_types;
+ MemoryTypeMap input_types;
+ MemoryTypeMap output_types;
+ std::vector<ControlFlowInfo> cf_info;
+};
+
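+// Returns the data type carried by an edge. Control edges carry no tensor
+// data, so a dummy DT_FLOAT is used for them.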
+DataType EdgeType(const Edge* e) {
+ if (e->IsControlEdge()) {
+ return DT_FLOAT;
+ } else {
+ return e->dst()->input_type(e->dst_input());
+ }
+}
+
+// Return true iff we need to add a same device send/recv for 'edge'.
+bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
+ if (edge->IsControlEdge()) {
+ return false;
+ }
+
+ Node* src = edge->src();
+ Node* dst = edge->dst();
+ if (src->assigned_device_name() == dst->assigned_device_name()) {
+ int src_port = edge->src_output();
+ int dst_port = edge->dst_input();
+ if (info.device_types[src->id()] == DEVICE_GPU) {
+ auto src_it = info.output_types.find({src->id(), src_port});
+ DCHECK(src_it != info.output_types.end());
+ auto dst_it = info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != info.input_types.end());
+ return src_it->second != dst_it->second;
+ }
+ }
+ return false;
+}
+
+// Return true iff (dst, dst_input) is specified on host memory.
+bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
+ Node* dst = edge->dst();
+ int dst_port = edge->dst_input();
+ if (info.device_types[dst->id()] == DEVICE_GPU) {
+ if (edge->IsControlEdge()) return false;
+ auto dst_it = info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != info.input_types.end());
+ return dst_it->second == HOST_MEMORY;
+ }
+ return true;
+}
+
+// Add an input to dst that comes from the "src_slot" output of the
+// node named by "src_name".
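+// For example:
+//   AddInput(dst, "foo", Graph::kControlSlot) appends "^foo",
+//   AddInput(dst, "foo", 0) appends "foo", and
+//   AddInput(dst, "foo", 2) appends "foo:2".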
+void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
+ if (src_slot == Graph::kControlSlot) {
+ dst->add_input(strings::StrCat("^", src_name));
+ } else if (src_slot == 0) {
+ dst->add_input(src_name.data(), src_name.size());
+ } else {
+ dst->add_input(strings::StrCat(src_name, ":", src_slot));
+ }
+}
+
+// Add a control edge from each input to each recv.
+void AddReadControl(const std::vector<NodeDef*>& recvs,
+ const std::vector<string>& inputs) {
+ for (NodeDef* recv : recvs) {
+ for (const string& input : inputs) {
+ recv->add_input(strings::StrCat("^", input));
+ }
+ }
+}
+
+void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
+ NodeDefBuilder* builder) {
+ builder->Attr("tensor_name",
+ strings::StrCat("edge_", edge->id(), "_", edge->src()->name()));
+ builder->Attr("send_device", edge->src()->assigned_device_name());
+ builder->Attr("send_device_incarnation",
+ static_cast<int64>(
+ opts.get_incarnation(edge->src()->assigned_device_name())));
+ builder->Attr("recv_device", edge->dst()->assigned_device_name());
+ builder->Attr("client_terminated", false);
+}
+
+NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
+ GraphDef* gdef, const Edge* edge,
+ NodeDefBuilder::NodeOut send_from, int64 start_time,
+ Status* status) {
+ const DataType dtype = send_from.data_type;
+ const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
+ const Node* src = edge->src();
+ const int src_port = edge->src_output();
+
+ // host_memory = true iff we need to use HostSend/HostCast.
+ bool host_memory = false;
+ if (!edge->IsControlEdge()) {
+ auto src_it = g_info.output_types.find({src->id(), src_port});
+ DCHECK(src_it != g_info.output_types.end());
+ host_memory = (src_it->second == HOST_MEMORY);
+ }
+
+ // Add a cast node that casts dtype to cast_dtype.
+ // NOTE(yuanbyu): Only cast for cross-device send/recv.
+ if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
+ const string cast_op = (host_memory) ? "_HostCast" : "Cast";
+ NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
+ cast_builder.Device(src->assigned_device_name()).Input(send_from);
+ if (opts.scheduling_for_recvs) {
+ cast_builder.Attr("_start_time", start_time);
+ }
+ cast_builder.Attr("DstT", cast_dtype);
+ NodeDef* cast = gdef->add_node();
+ *status = cast_builder.Finalize(cast);
+ if (!status->ok()) return nullptr;
+
+ // Connect the Send op to the cast.
+ send_from.Reset(cast->name(), 0, cast_dtype);
+ }
+
+ // Add the send node.
+ const string send_op = (host_memory) ? "_HostSend" : "_Send";
+ NodeDefBuilder send_builder(opts.new_name(src->name()), send_op);
+ SetSendRecvAttrs(opts, edge, &send_builder);
+ send_builder.Device(src->assigned_device_name()).Input(send_from);
+ if (opts.scheduling_for_recvs) {
+ send_builder.Attr("_start_time", start_time);
+ }
+ NodeDef* send = gdef->add_node();
+ *status = send_builder.Finalize(send);
+ return send;
+}
+
+NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
+ GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
+ Status* status) {
+ const DataType dtype = EdgeType(edge);
+ const Node* src = edge->src();
+ const Node* dst = edge->dst();
+ const int dst_port = edge->dst_input();
+ DataType cast_dtype = dtype;
+
+ // NOTE(yuanbyu): Only cast for cross-device send/recv.
+ if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
+ cast_dtype = opts.should_cast(edge);
+ }
+
+ // host_memory = true iff we need to use HostRecv/HostCast.
+ bool host_memory = false;
+ if (!edge->IsControlEdge()) {
+ auto dst_it = g_info.input_types.find({dst->id(), dst_port});
+ DCHECK(dst_it != g_info.input_types.end());
+ host_memory = (dst_it->second == HOST_MEMORY);
+ }
+
+ // Add the recv node.
+ const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
+ NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op);
+ SetSendRecvAttrs(opts, edge, &recv_builder);
+ recv_builder.Device(dst->assigned_device_name())
+ .Attr("tensor_type", cast_dtype);
+ NodeDef* recv = gdef->add_node();
+ *status = recv_builder.Finalize(recv);
+ if (!status->ok()) return nullptr;
+ *real_recv = recv;
+
+ // Add the cast node (from cast_dtype to dtype) or an Identity node.
+ if (dtype != cast_dtype) {
+ const string cast_op = (host_memory) ? "_HostCast" : "Cast";
+ NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op);
+ cast_builder.Attr("DstT", dtype);
+ cast_builder.Device(dst->assigned_device_name())
+ .Input(recv->name(), 0, cast_dtype);
+ NodeDef* cast = gdef->add_node();
+ *status = cast_builder.Finalize(cast);
+ if (!status->ok()) return nullptr;
+ return cast;
+ } else if (edge->IsControlEdge()) {
+ // An Identity is only needed for control edges.
+ NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity");
+ id_builder.Device(dst->assigned_device_name())
+ .Input(recv->name(), 0, cast_dtype);
+ NodeDef* id = gdef->add_node();
+ *status = id_builder.Finalize(id);
+ if (!status->ok()) return nullptr;
+ return id;
+ } else {
+ return recv;
+ }
+}
+
+NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
+ const Edge* edge, Status* status) {
+ const Node* src = edge->src();
+ Tensor tensor(DT_FLOAT, TensorShape({0}));
+ NodeDef* result = gdef->add_node();
+ *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
+ .Device(src->assigned_device_name())
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", tensor)
+ .Finalize(result);
+ return result;
+}
+
+// A dummy node for scheduling.
+NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
+ const string& assigned_device_name, int64 epoch,
+ int64 starttime, Status* status) {
+ NodeDef* result = gdef->add_node();
+ *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
+ "ControlTrigger")
+ .Device(assigned_device_name)
+ .Attr("_start_time", starttime)
+ .Finalize(result);
+ return result;
+}
+
+// Assign to each node the name of the frame and the level it belongs to.
+// We check the well-formedness of the graph: All inputs to a node must
+// come from the same frame and have the same "static" iteration level.
+// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level
+// 0. This essentially means there can't be multiple serial Nexts in
+// an iteration, which all sane front-ends should satisfy.
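+//
+// For example, nodes reached through an Enter node with frame_name "f" get
+// frame "f" and iteration level 0, a NextIteration node in that frame gets
+// level 1, and nodes reached through the matching Exit revert to the
+// enclosing frame's info.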
+Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
+ info->clear();
+ info->resize(g->num_node_ids());
+
+ Node* src_node = g->source_node();
+ ControlFlowInfo& src_info = (*info)[src_node->id()];
+ src_info.frame = src_node;
+ src_info.parent_frame = src_node;
+ src_info.iter_level = 0;
+
+ string frame_name;
+ std::deque<const Node*> ready;
+ ready.push_back(src_node);
+ while (!ready.empty()) {
+ const Node* curr_node = ready.front();
+ ready.pop_front();
+ const ControlFlowInfo& curr_info = (*info)[curr_node->id()];
+ const Node* frame = curr_info.frame;
+ const Node* parent = curr_info.parent_frame;
+ frame_name = curr_info.frame_name;
+ int iter_level = curr_info.iter_level;
+
+ if (IsExit(curr_node)) {
+ const ControlFlowInfo& parent_info = (*info)[parent->id()];
+ frame = parent_info.frame;
+ parent = parent_info.parent_frame;
+ frame_name = parent_info.frame_name;
+ iter_level = parent_info.iter_level;
+ }
+
+ for (const Edge* out_edge : curr_node->out_edges()) {
+ const Node* out = out_edge->dst();
+ int out_id = out->id();
+ ControlFlowInfo* out_info = &(*info)[out_id];
+ const Node* out_parent = out_info->parent_frame;
+ bool is_visited = (out_info->iter_level != -1);
+
+ // Skip Sink/Source nodes.
+ if (!out->IsOp()) continue;
+
+ // Add to ready queue if not seen.
+ if (!is_visited) {
+ ready.push_back(out);
+ }
+
+ // Process the node 'out'.
+ if (IsEnter(out)) {
+ if (is_visited) {
+ const string& parent_name = (*info)[out_parent->id()].frame_name;
+ if (parent_name != frame_name || iter_level != out_info->iter_level) {
+ return errors::InvalidArgument(
+ "All inputs to Enter must be from the same frame and level.");
+ }
+ } else {
+ out_info->frame = out;
+ out_info->parent_frame = frame;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
+ if (out_info->frame_name.empty()) {
+ return errors::InvalidArgument(
+ "Enter must have a non-empty frame name.");
+ }
+ out_info->iter_level = 0;
+ }
+ } else if (IsNextIteration(out)) {
+ if (is_visited) {
+ if (out_info->frame_name != frame_name ||
+ out_info->iter_level != (iter_level + 1)) {
+ return errors::InvalidArgument(
+ "All inputs to NextIteration must be from the same frame "
+ "and level.");
+ }
+ } else {
+ out_info->frame = frame;
+ out_info->parent_frame = parent;
+ out_info->frame_name = frame_name;
+ out_info->iter_level = iter_level + 1;
+ }
+ } else {
+ if (is_visited) {
+ if (out_info->frame_name != frame_name) {
+ return errors::InvalidArgument(
+ "All inputs to a node must be from the same frame.");
+ }
+ } else {
+ out_info->frame = frame;
+ out_info->parent_frame = parent;
+ out_info->frame_name = frame_name;
+ out_info->iter_level = iter_level;
+ }
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+string ControlLoopName(const string& name) {
+ return strings::StrCat("_cloop", name);
+}
+
+bool IsControlLoop(const Node* node) {
+ const string& name = node->def().name();
+ return StringPiece(name).starts_with("_cloop");
+}
+
+// An enter node for control flow.
+Node* AddControlEnter(Graph* g, const string& node_name,
+ const string& device_name, const string& frame_name,
+ const int parallel_iterations, Status* status) {
+ NodeBuilder node_builder(node_name, "Enter", g->op_registry());
+ node_builder.Input({"dummy", 0, DT_FLOAT});
+ node_builder.Attr("frame_name", frame_name);
+ node_builder.Attr("parallel_iterations", parallel_iterations);
+ Node* res_node;
+ *status = node_builder.Finalize(g, &res_node);
+ if (!status->ok()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A merge node for control flow.
+Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
+ const string& node_name, const string& device_name,
+ Status* status) {
+ NodeBuilder node_builder(node_name, "Merge", g->op_registry());
+ node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
+ Node* res_node;
+ *status = node_builder.Finalize(g, &res_node);
+ if (!status->ok()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A switch node for control flow.
+Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
+ const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = ops::BinaryOp("Switch", input1, input2, bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A next_iteration node for control flow.
+Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = ops::UnaryOp("NextIteration", input, bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+Node* EmptyConst(const GraphDefBuilder::Options& options) {
+ if (options.HaveError()) return nullptr;
+ NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
+ options.op_registry());
+ const DataType dt = DataTypeToEnum<float>::v();
+ TensorProto proto;
+ proto.set_dtype(dt);
+ TensorShape empty_shape({0});
+ empty_shape.AsProto(proto.mutable_tensor_shape());
+ node_builder.Attr("dtype", dt).Attr("value", proto);
+ return options.FinalizeBuilder(&node_builder);
+}
+
+// A dummy const node for control flow.
+Node* AddControlConst(const string& device_name,
+ const GraphDefBuilder::Options& bopts) {
+ Node* res_node = EmptyConst(bopts);
+ if (bopts.HaveError()) return nullptr;
+ res_node->set_assigned_device_name(device_name);
+ return res_node;
+}
+
+// A synthetic loop, made up of dummy nodes. It performs control-flow actions
+// on behalf of a leader on a different device.
+struct ControlLoop {
+ Node* enter = nullptr;
+ Node* merge = nullptr;
+ Node* switch_node = nullptr;
+};
+
+// Add the control flow info of a new node added during partitioning.
+// The new node has the same control flow info as edge->src().
+void AddControlFlowInfo(const Node* node, const Node* src,
+ std::vector<ControlFlowInfo>* cf_info) {
+ int id = node->id();
+ if (static_cast<size_t>(id) >= cf_info->size()) {
+ cf_info->resize(id + 1);
+ }
+ const ControlFlowInfo& src_info = (*cf_info)[src->id()];
+ ControlFlowInfo* info = &(*cf_info)[id];
+ info->frame = src_info.frame;
+ info->parent_frame = src_info.parent_frame;
+ info->frame_name = src_info.frame_name;
+ info->iter_level = src_info.iter_level;
+}
+
+// Constructs a control loop. Returns a struct containing the newly created
+// enter, merge, and switch nodes. The enter and merge nodes are used in the
+// recursive construction of control loops for nested frames (loops). The
+// switch node will be connected to the LoopCond node. The merge node will
+// be connected to all the recvs of the same frame by control edges when
+// the actual partitioning happens.
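+//
+// The wiring created below is (sketch):
+//   enter -> merge (input 0), next -> merge (input 1),
+//   merge -> switch_node (data input), loop_cond -> switch_node (predicate),
+//   switch_node:1 -> next (a NextIteration node).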
+Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
+ const Edge* edge, Node* loop_cond,
+ std::vector<ControlFlowInfo>* cf_info,
+ ControlLoop* loop) {
+ Status status;
+ GraphDefBuilder::Options bopts(g, &status);
+ const ControlFlowInfo& src_info = (*cf_info)[src->id()];
+ const string& device_name = edge->dst()->assigned_device_name();
+ const string& frame_name = src_info.frame_name;
+ int parallel_iterations;
+ status = GetNodeAttr(src_info.frame->def(), "parallel_iterations",
+ &parallel_iterations);
+ if (!status.ok()) return status;
+
+ // The names of the nodes to be added.
+ const string& enter_name =
+ ControlLoopName(opts.new_name(edge->dst()->name()));
+ const string& merge_name =
+ ControlLoopName(opts.new_name(edge->dst()->name()));
+ const string& switch_name =
+ ControlLoopName(opts.new_name(edge->dst()->name()));
+ const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));
+
+ // Add the nodes to the graph g.
+ Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
+ parallel_iterations, &status);
+ if (!status.ok()) return status;
+ Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
+ device_name, &status);
+ if (!status.ok()) return status;
+ Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
+ bopts.WithName(switch_name));
+ if (!status.ok()) return status;
+ Node* next =
+ AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
+ if (!status.ok()) return status;
+
+ // Add control flow info for these new nodes:
+ AddControlFlowInfo(enter, src, cf_info);
+ AddControlFlowInfo(merge, src, cf_info);
+ AddControlFlowInfo(switch_node, src, cf_info);
+ AddControlFlowInfo(next, src, cf_info);
+
+ // Add input edges for the newly created merge node:
+ g->AddEdge(enter, 0, merge, 0);
+ g->AddEdge(next, 0, merge, 1);
+
+ loop->enter = enter;
+ loop->merge = merge;
+ loop->switch_node = switch_node;
+ return Status::OK();
+}
+
+// Build memory and device type info for every node in the graph.
+// TODO(yuanbyu): It might be simpler if we convert MemoryType to
+// DeviceType for the inputs/outputs of each node.
+Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
+ Status status;
+ MemoryTypeVector input_memory_types;
+ MemoryTypeVector output_memory_types;
+
+ info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
+ for (const Node* node : g.nodes()) {
+ if (!node->IsOp()) continue; // Skip Sink/Source nodes.
+
+ DeviceNameUtils::ParsedName parsed;
+ if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
+ &parsed)) {
+ return errors::Internal("Malformed assigned device '",
+ node->assigned_device_name(), "'");
+ }
+
+ input_memory_types.clear();
+ input_memory_types.resize(node->num_inputs());
+ output_memory_types.clear();
+ output_memory_types.resize(node->num_outputs());
+ status = MemoryTypesForNode(g.op_registry(), DeviceType(parsed.type),
+ node->def(), &input_memory_types,
+ &output_memory_types);
+ if (!status.ok()) return status;
+
+ int node_id = node->id();
+ info->device_types[node_id] = DeviceType(parsed.type);
+ for (size_t i = 0; i < input_memory_types.size(); ++i) {
+ info->input_types[{node_id, i}] = input_memory_types[i];
+ }
+ for (size_t i = 0; i < output_memory_types.size(); ++i) {
+ info->output_types[{node_id, i}] = output_memory_types[i];
+ }
+ }
+ return status;
+}
+
+// Each participating device needs to decide a) if there is a next iteration,
+// and b) if the loop terminates. We take the approach of encoding this control
+// flow logic in the dataflow graph. There are at least two possible encodings.
+// In a completely decentralized encoding, the participants communicate peer
+// to peer. The other encoding uses a frame leader (the participant who owns
+// the pivot termination predicate) to broadcast the termination condition to
+// all the participants. For now we take the latter because it is simpler.
+//
+// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
+// it wrong many times so it would be nice to write a proof to be sure.
+Status AddControlFlow(const PartitionOptions& opts, Graph* g,
+ GraphInfo* g_info) {
+ Status status;
+ GraphDefBuilder::Options bopts(g, &status);
+ std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;
+
+ // Build the control flow info for every node.
+ status = BuildControlFlowInfo(g, &cf_info);
+ if (!status.ok()) return status;
+
+ // The map from frames to their LoopCond nodes.
+ std::unordered_map<string, Node*> frame_cond_map;
+ int num_node_ids = g->num_node_ids();
+ for (int i = 0; i < num_node_ids; ++i) {
+ Node* node = g->FindNodeId(i);
+ if (node == nullptr) continue;
+
+ if (IsLoopCond(node)) {
+ const string& frame_name = cf_info[node->id()].frame_name;
+ DCHECK(!frame_name.empty());
+ frame_cond_map[frame_name] = node;
+ }
+ }
+
+ // Add all control loops for cross-device frames.
+ // A control loop is added only when there is a cross-device edge in a
+// non-root frame. Nothing is added if there are no loops. We also don't
+ // add anything for a frame that is completely local to a device. For
+ // nested loops, we stack the control loops together by connecting
+ // the merge of the outer loop to the enter of the inner loop.
+ //
+ // A map from <frame_name, device_name> to ControlLoop.
+ std::unordered_map<string, ControlLoop> control_loops;
+ int num_edge_ids = g->num_edge_ids();
+ for (int i = 0; i < num_edge_ids; ++i) {
+ const Edge* edge = g->FindEdgeId(i);
+ if (edge == nullptr) continue;
+
+ const Node* src = edge->src();
+ const Node* dst = edge->dst();
+ // Skip Sink/Source nodes.
+ if (!src->IsOp() || !dst->IsOp()) continue;
+
+ const string& src_device = src->assigned_device_name();
+ const string& dst_device = dst->assigned_device_name();
+ // Skip local edges.
+ if (src_device == dst_device) continue;
+
+ const string& src_frame = cf_info[src->id()].frame_name;
+ const string& dst_frame = cf_info[dst->id()].frame_name;
+ // Skip if src and dst are not in the same frame.
+ if (src_frame.empty() || src_frame != dst_frame) {
+ continue;
+ }
+
+ // Add the control loop. Start by adding the control loop for the
+ // current frame if needed, and recursively adding the control loop
+ // for its outer frame when nested.
+ ControlLoop child_loop;
+ while (true) {
+ const string& curr_frame = cf_info[src->id()].frame_name;
+ if (curr_frame.empty()) {
+ // We have reached the root frame.
+ if (child_loop.merge != nullptr) {
+ const string& node_name = opts.new_name(edge->dst()->name());
+ const string& device_name = edge->dst()->assigned_device_name();
+ Node* const_node =
+ AddControlConst(device_name, bopts.WithName(node_name));
+ if (!status.ok()) return status;
+ AddControlFlowInfo(const_node, src, &cf_info);
+ g->AddEdge(const_node, 0, child_loop.enter, 0);
+ }
+ break;
+ }
+
+ const string& cl_key = strings::StrCat(curr_frame, "$$", dst_device);
+ auto it = control_loops.find(cl_key);
+ if (it != control_loops.end()) {
+ if (child_loop.enter != nullptr) {
+ g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
+ }
+ break;
+ }
+
+ // Get the frame's LoopCond.
+ auto cond_it = frame_cond_map.find(curr_frame);
+ if (cond_it == frame_cond_map.end()) {
+ return errors::InvalidArgument(
+ "A cross-device loop must have a pivot predicate: ", curr_frame);
+ }
+ Node* loop_cond = cond_it->second;
+
+ // Add the control loop.
+ ControlLoop curr_loop;
+ status =
+ AddControlLoop(opts, g, src, edge, loop_cond, &cf_info, &curr_loop);
+ if (!status.ok()) return status;
+ control_loops[cl_key] = curr_loop;
+
+ if (child_loop.enter != nullptr) {
+ // Connect the merge of the outer loop to the enter of the inner.
+ g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
+ }
+ src = cf_info[src->id()].parent_frame;
+ child_loop = curr_loop;
+ }
+ }
+
+ // For a cross-device edge, on the dst device, add a control edge
+ // from the merge node of the control loop to dst. If a send/recv is
+ // introduced for this edge in future partitioning, we delete this
+ // control edge and add a new control edge from the merge to the recv.
+ num_edge_ids = g->num_edge_ids();
+ for (int i = 0; i < num_edge_ids; ++i) {
+ const Edge* edge = g->FindEdgeId(i);
+ if (edge == nullptr) continue;
+
+ const Node* src = edge->src();
+ Node* dst = edge->dst();
+ // Skip Sink/Source nodes.
+ if (!src->IsOp() || !dst->IsOp()) continue;
+
+ const string& src_device = src->assigned_device_name();
+ const string& dst_device = dst->assigned_device_name();
+ if (src_device != dst_device) {
+ const string& src_frame = cf_info[src->id()].frame_name;
+ const string& dst_frame = cf_info[dst->id()].frame_name;
+ if (!src_frame.empty() && src_frame == dst_frame) {
+ const string& cl_key = strings::StrCat(dst_frame, "$$", dst_device);
+ ControlLoop loop = control_loops[cl_key];
+ DCHECK(loop.enter != nullptr);
+ g->AddControlEdge(loop.merge, dst);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // end namespace
+
+Status AddControlEdges(const PartitionOptions& opts,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Status status;
+ // TODO(yuanbyu): Very naive for now. To be improved.
+ const int num_epochs = 100;
+ const int prefetch = 6;
+
+ typedef std::pair<const NodeDef*, int64> NodeStartTime;
+ for (auto& part : *partitions) {
+ GraphDef* gdef = &part.second;
+
+ std::vector<NodeStartTime> start_times;
+ start_times.resize(gdef->node_size());
+ for (int n = 0; n < gdef->node_size(); ++n) {
+ const NodeDef& ndef = gdef->node(n);
+ int64 start_time;
+ status = GetNodeAttr(ndef, "_start_time", &start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ start_times[n] = std::make_pair(&ndef, start_time);
+ }
+
+ // Sort the nodes based on their start times.
+ std::sort(
+ start_times.begin(), start_times.end(),
+ [](NodeStartTime x, NodeStartTime y) { return x.second < y.second; });
+
+ // Add a dummy node for every epoch, and add a control edge from the
+ // "last" node in the preceding epoch to the dummy node.
+ string device_name = gdef->node(0).device();
+ int64 makespan = start_times.back().second;
+ int64 resolution = (makespan / num_epochs) + 1;
+
+ int i = 0;
+ int j = 0;
+ std::vector<NodeDef*> dummys;
+ while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
+ if (i * resolution > start_times[j].second) {
+ j++;
+ } else {
+ NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
+ i * resolution, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ dummys.push_back(dummy);
+ if (j > 0) {
+ string src_name = start_times[j - 1].first->name();
+ AddInput(dummy, src_name, Graph::kControlSlot);
+ }
+ i++;
+ }
+ }
+
+ // Finally, add the control edges to recvs.
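+ // For example, with resolution 10 and prefetch 6, a recv whose start
+ // time is 85 falls in epoch 8 (85 / 10) and therefore waits on the
+ // control trigger added for epoch 2 (8 - 6).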
+ for (int n = 0; n < gdef->node_size(); ++n) {
+ NodeDef* ndef = gdef->mutable_node(n);
+ if (ndef->op() == "_Recv") {
+ int64 start_time;
+ status = GetNodeAttr(*ndef, "_start_time", &start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ int recv_epoch = start_time / resolution;
+ if (recv_epoch >= prefetch) {
+ NodeDef* dummy = dummys[recv_epoch - prefetch];
+ AddInput(ndef, dummy->name(), Graph::kControlSlot);
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status Partition(const PartitionOptions& opts, Graph* g,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Status status;
+ partitions->clear();
+
+ GraphInfo g_info;
+ if (!opts.control_flow_added) {
+ // Add the "code" for distributed execution of control flow. Code is
+ // added only for the frames that are placed on multiple devices. The
+ // new graph is an equivalent transformation of the original graph and
+ // has the property that it can be subsequently partitioned arbitrarily
+ // (down to the level of individual devices) for distributed execution.
+ status = AddControlFlow(opts, g, &g_info);
+ if (!status.ok()) return status;
+ }
+ // At this point, all the graph mutations have been done. Build memory
+ // and device type info for every node and edge in the graph.
+ status = BuildMemoryDeviceInfo(*g, &g_info);
+ if (!status.ok()) return status;
+
+ string dstp;
+ std::vector<const Edge*> inputs;
+ DupRecvTable dup_recv(3);
+ // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
+ // edge to dst. 'ref_control_inputs' remembers the inputs to dst that
+ // arrive via non-ref edges. We will add a control edge for every pair in
+ // (ref_recvs x ref_control_inputs).
+ std::vector<NodeDef*> ref_recvs;
+ std::vector<string> ref_control_inputs;
+
+ int32 num_data = 0;
+ int32 num_control = 0;
+ for (const Node* dst : g->nodes()) {
+ if (!dst->IsOp()) continue; // Skip Sink/Source nodes.
+
+ dstp = opts.node_to_loc(dst);
+ GraphDef* dst_graph = &(*partitions)[dstp];
+ NodeDef* dst_def = dst_graph->add_node();
+ *dst_def = dst->def();
+ dst_def->set_device(dst->assigned_device_name());
+ dst_def->clear_input(); // Inputs are filled below
+ if (opts.need_to_record_start_times) {
+ int64 start_time = opts.start_times[dst->id()].value();
+ AddNodeAttr("_start_time", start_time, dst_def);
+ }
+
+ // Arrange the incoming edges to dst so that inputs[i] holds the
+ // input flowing into slot numbered i. Trailing entries in inputs[]
+ // hold control edges.
+ inputs.clear();
+ inputs.resize(dst->num_inputs(), nullptr);
+ ref_recvs.clear();
+ ref_control_inputs.clear();
+ const Edge* control_flow_edge = nullptr;
+ for (const Edge* edge : dst->in_edges()) {
+ if (edge->IsControlEdge()) {
+ if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
+ // This is one of the control edges added for control flow. There
+ // can be multiple such edges as the dest node may have multiple
+ // remote inputs. We will just take one and ignore the others.
+ control_flow_edge = edge;
+ } else {
+ inputs.push_back(edge);
+ }
+ } else {
+ DCHECK(inputs[edge->dst_input()] == nullptr);
+ inputs[edge->dst_input()] = edge;
+ }
+ }
+
+ // Process in order so that all data edges are added as inputs to
+ // dst in Edge::dst_input() order.
+ bool recv_added = false;
+ for (const Edge* edge : inputs) {
+ const Node* src = edge->src();
+ if (!src->IsOp()) continue; // Skip Sink/Source nodes.
+
+ GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
+ if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
+ // Same partition and compatible memory types:
+ AddInput(dst_def, src->name(), edge->src_output());
+ if (edge->IsControlEdge() ||
+ !IsRefType(src->output_type(edge->src_output()))) {
+ ref_control_inputs.push_back(src->name());
+ }
+ continue;
+ }
+
+ int64 send_start_time = 0;
+ int64 recv_start_time = 0;
+ if (opts.scheduling_for_recvs) {
+ if (opts.need_to_record_start_times) {
+ send_start_time = opts.start_times[src->id()].value();
+ recv_start_time = opts.start_times[dst->id()].value();
+ } else {
+ status = GetNodeAttr(src->def(), "_start_time", &send_start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+ }
+
+ // Check whether there is already a send/recv pair transferring
+ // the same tensor/control from the src to dst partition.
+ const bool on_host = IsDstInputOnHost(edge, g_info);
+ DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
+ auto iter = dup_recv.find(key);
+ if (iter != dup_recv.end()) {
+ // We found one. Reuse the data/control transferred already.
+ const string& recv_node_name = iter->second.recv->name();
+ if (edge->IsControlEdge()) {
+ AddInput(dst_def, recv_node_name, Graph::kControlSlot);
+ } else {
+ AddInput(dst_def, recv_node_name, 0);
+ }
+ // We want the start_time for the recv to be the smallest of the start
+ // times of its consumers. So we update this whenever we use a recv,
+ // and write it out to the attribute at the end of the subroutine.
+ if (iter->second.start_time > recv_start_time) {
+ iter->second.start_time = recv_start_time;
+ }
+ continue;
+ }
+
+ NodeDefBuilder::NodeOut send_from;
+ if (edge->IsControlEdge()) {
+ // Insert a dummy const node that will generate a tiny
+ // data element to be sent from send to recv.
+ VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
+ << src->name() << "] -> " << dst->assigned_device_name() << "["
+ << dst->name() << "]";
+ NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
+ if (!status.ok()) return status;
+ // Set the start time for this dummy node.
+ if (opts.scheduling_for_recvs) {
+ AddNodeAttr("_start_time", send_start_time, dummy);
+ }
+ AddInput(dummy, src->name(), Graph::kControlSlot);
+ send_from.Reset(dummy->name(), 0, DT_FLOAT);
+ } else {
+ send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
+ }
+
+ // Need to split edge by placing matching send/recv nodes on
+ // the src/dst sides of the edge.
+ NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
+ send_start_time, &status);
+ if (!status.ok()) return status;
+
+ NodeDef* real_recv = nullptr;
+ NodeDef* recv =
+ AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
+ if (!status.ok()) return status;
+
+ // Fix up the control flow edge. Redirect it to the recv.
+ // NOTE(yuanbyu): 'real_recv' must be the real recv node.
+ recv_added = true;
+ if (control_flow_edge != nullptr) {
+ AddInput(real_recv, control_flow_edge->src()->name(),
+ Graph::kControlSlot);
+ }
+
+ // For same device send/recv, add a control edge from send to recv.
+ // This prevents the asynchronous recv kernel from being scheduled
+ // immediately.
+ if (src_graph == dst_graph) {
+ AddInput(real_recv, send->name(), Graph::kControlSlot);
+ }
+
+ if (!edge->IsControlEdge() &&
+ IsRefType(src->output_type(edge->src_output()))) {
+ // If src is of ref type and the edge is not a control edge, dst has
+ // read semantics and therefore we must control the recv.
+ ref_recvs.push_back(real_recv);
+ } else {
+ // Memorize the send/recv pair, only if this is not a "ref" edge.
+ // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
+ // for now we don't do it.
+ dup_recv[key] = {recv, real_recv, recv_start_time};
+ ref_control_inputs.push_back(recv->name());
+ }
+
+ if (edge->IsControlEdge()) {
+ ++num_control;
+ AddInput(dst_def, recv->name(), Graph::kControlSlot);
+ } else {
+ ++num_data;
+ AddInput(dst_def, recv->name(), 0);
+ }
+ }
+
+ // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
+ // NOTE(yuanbyu): Adding these control edges should not introduce
+ // deadlocks. 'dst' has implicit "read" nodes that, when we split
+ // across devices, are made explicit. Retargeting the dependencies of
+ // 'dst' to those nodes does not introduce cycles if there were none
+ // before the transformation.
+ // NOTE(yuanbyu): This may impact performance because it defers the
+ // execution of recvs until all the other inputs become available.
+ AddReadControl(ref_recvs, ref_control_inputs);
+
+ // Add back this control edge for control flow if not used.
+ if (!recv_added && (control_flow_edge != nullptr)) {
+ AddInput(dst_def, control_flow_edge->src()->name(), Graph::kControlSlot);
+ }
+ }
+
+ // Set the start times for recvs at the very end.
+ if (opts.scheduling_for_recvs) {
+ for (auto& it : dup_recv) {
+ AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
+ if (it.second.real_recv != it.second.recv) {
+ AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
+ }
+ }
+ }
+
+ VLOG(1) << "Added send/recv: controls=" << num_control
+ << ", data=" << num_data;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
new file mode 100644
index 0000000000..eb88ff71b1
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/costmodel.h"
+
+namespace tensorflow {
+
+struct PartitionOptions {
+ // A function that returns a location for the execution of a given
+ // Node.
+ typedef std::function<string(const Node*)> NodeToLocFunc;
+ NodeToLocFunc node_to_loc = nullptr;
+
+ // A function that returns a unique graph node name with the given
+ // prefix.
+ typedef std::function<string(const string&)> NewNameFunc;
+ NewNameFunc new_name = nullptr;
+
+ // A function that returns the incarnation of a device given the
+ // device's fullname. If not found, GetIncarnationFunc should return
+ // kIllegalIncarnation.
+ static const uint64 kIllegalIncarnation = 0;
+ typedef std::function<uint64(const string&)> GetIncarnationFunc;
+ GetIncarnationFunc get_incarnation = nullptr;
+
+ // True if all the control flow "code" has already been added. The
+ // control flow code needs to be added when we still have the entire
+ // graph before any partitioning. So this flag should be false for
+ // the first partitioning but true for all subsequent partitioning.
+ //
+ // TODO(yuanbyu): We could also make the addition of the control
+ // flow code incremental based on 'node_to_loc'. This makes the
+ // communication a broadcast tree, which could be more efficient when
+ // the number of participating devices is large.
+ bool control_flow_added;
+
+ // A function that returns the data type into which the tensor
+ // should be cast before it is sent over the wire.
+ typedef std::function<DataType(const Edge*)> ShouldCastFunc;
+ ShouldCastFunc should_cast = nullptr;
+
+ // Schedule the execution of the recvs based on their start times
+ // computed by some scheduling algorithm. The recvs are divided into
+ // epochs based on their start times. A recv is enabled only when
+ // execution reaches its epoch - N for some predefined N.
+ bool scheduling_for_recvs = false;
+ // The start time for each node in the graph computed by some scheduling
+ // algorithm. If 'need_to_record_start_times' is true, we record them
+ // in the graph as a node attribute.
+ bool need_to_record_start_times = false;
+ std::vector<Microseconds> start_times;
+};
+
+// Partition "input" graph into a set of graphs, one per location.
+// The location for node n is derived by calling opts.node_to_loc(n).
+// New nodes added by Partition use "opts.new_name(old_name)" to
+// generate node names.
+//
+// Stores the partitions in *partitions.
+Status Partition(const PartitionOptions& opts, Graph* input,
+ std::unordered_map<string, GraphDef>* partitions);
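+//
+// A minimal usage sketch (illustrative only; it assumes every node in 'g'
+// already has an assigned device and uses a placeholder incarnation
+// callback):
+//
+//   PartitionOptions popts;
+//   popts.node_to_loc = [](const Node* n) {
+//     return n->assigned_device_name();
+//   };
+//   popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
+//   popts.get_incarnation = [](const string& device) -> uint64 { return 1; };
+//   popts.control_flow_added = false;
+//   std::unordered_map<string, GraphDef> partitions;
+//   Status s = Partition(popts, &g, &partitions);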
+
+// Add control edges to the partitions to control the ordering
+// and timing of the recv nodes based on the start times calculated
+// using some scheduling algorithm.
+Status AddControlEdges(const PartitionOptions& opts,
+ std::unordered_map<string, GraphDef>* partitions);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
new file mode 100644
index 0000000000..d912c94025
--- /dev/null
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -0,0 +1,316 @@
+#include "tensorflow/core/graph/graph_partition.h"
+
+#include <unordered_map>
+
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/control_flow_ops.h"
+#include "tensorflow/cc/ops/random_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/equal_graph_def.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace {
+
+const char gpu_device[] = "/job:a/replica:0/task:0/gpu:0";
+
+string SplitByDevice(const Node* node) { return node->assigned_device_name(); }
+
+string DeviceName(const Node* node) {
+ char first = node->name()[0];
+ if (first == 'G') {
+ return gpu_device;
+ } else {
+ const string cpu_prefix = "/job:a/replica:0/task:0/cpu:";
+ int index = first - 'A';
+ return strings::StrCat(cpu_prefix, index);
+ }
+}
+
+void Partition(const GraphDef& graph_def,
+ std::unordered_map<string, GraphDef>* partitions) {
+ Graph g(OpRegistry::Global());
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g));
+
+ // Assigns devices to each node. Uses 1st letter of the node name as
+ // the device index.
+ for (Node* node : g.nodes()) {
+ node->set_assigned_device_name(DeviceName(node));
+ }
+
+ PartitionOptions popts;
+ popts.node_to_loc = SplitByDevice;
+ popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); };
+ popts.get_incarnation = [](const string& name) {
+ return (name[0] - 'A') + 100;
+ };
+ popts.control_flow_added = false;
+ Status s = Partition(popts, &g, partitions);
+ CHECK(s.ok()) << s;
+}
+
+void CheckLoopConstruction(const GraphDef& graph_def) {
+ std::unordered_map<string, GraphDef> partitions;
+ Partition(graph_def, &partitions);
+ GraphConstructorOptions opts;
+ for (const auto& kv : partitions) {
+ const GraphDef& gdef = kv.second;
+ bool has_control_enter = false;
+ bool has_control_merge = false;
+ bool has_control_switch = false;
+ bool has_control_next = false;
+ for (const NodeDef& ndef : gdef.node()) {
+ // _recvs must have a control input
+ if (ndef.op() == "_Recv") {
+ bool has_control = false;
+ for (const string& input_name : ndef.input()) {
+ if (StringPiece(input_name).starts_with("^")) {
+ has_control = true;
+ break;
+ }
+ }
+ EXPECT_TRUE(has_control);
+ }
+ // Must have a control loop
+ if (StringPiece(ndef.name()).starts_with("_cloop")) {
+ if (ndef.op() == "Enter") {
+ has_control_enter = true;
+ }
+ if (ndef.op() == "Merge") {
+ has_control_merge = true;
+ }
+ if (ndef.op() == "Switch") {
+ has_control_switch = true;
+ }
+ if (ndef.op() == "NextIteration") {
+ has_control_next = true;
+ }
+ }
+ }
+ EXPECT_TRUE(has_control_enter);
+ EXPECT_TRUE(has_control_merge);
+ EXPECT_TRUE(has_control_switch);
+ EXPECT_TRUE(has_control_next);
+ }
+}
+
+REGISTER_OP("Input").Output("o: float");
+REGISTER_OP("BoolInput").Output("o: bool");
+REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float");
+
+Node* Input(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("Input", opts);
+}
+
+Node* BoolInput(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("BoolInput", opts);
+}
+
+Node* Cross(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("Cross", a, b, opts);
+}
+
+class GraphPartitionTest : public ::testing::Test {
+ protected:
+ GraphPartitionTest()
+ : in_(GraphDefBuilder::kFailImmediately),
+ builder_a_(GraphDefBuilder::kFailImmediately),
+ builder_b_(GraphDefBuilder::kFailImmediately),
+ a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")),
+ b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) {
+ RequireDefaultOps();
+ }
+
+ const GraphDef& ToGraphDef() {
+ in_.ToGraphDef(&in_graph_def_);
+ return in_graph_def_;
+ }
+
+ void ExpectMatchA() {
+ GraphDef graph_def;
+ builder_a_.ToGraphDef(&graph_def);
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]);
+ }
+
+ void ExpectMatchB() {
+ GraphDef graph_def;
+ builder_b_.ToGraphDef(&graph_def);
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]);
+ }
+
+ GraphDefBuilder in_;
+ GraphDef in_graph_def_;
+ GraphDefBuilder builder_a_;
+ GraphDefBuilder builder_b_;
+ GraphDefBuilder::Options a_opts_;
+ GraphDefBuilder::Options b_opts_;
+ std::unordered_map<string, GraphDef> partitions_;
+};
+
+TEST_F(GraphPartitionTest, SingleDevice) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Cross(a1, a1, in_.opts().WithName("A2"));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(1, partitions_.size());
+
+ a1 = Input(a_opts_.WithName("A1"));
+ Cross(a1, a1, a_opts_.WithName("A2"));
+ ExpectMatchA();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceData) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(a1, b1, in_.opts().WithName("B2"));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
+ ExpectMatchA();
+
+ b1 = Input(b_opts_.WithName("B1"));
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
+ Cross(recv, b1, b_opts_.WithName("B2"));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceControl) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
+ _Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1"));
+ ExpectMatchA();
+
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2"));
+ Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(a1, b1, in_.opts().WithName("B2"));
+ Cross(a1, a1, in_.opts().WithName("B3"));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0"));
+ ExpectMatchA();
+
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(recv, b1, b_opts_.WithName("B2"));
+ Cross(recv, recv, b_opts_.WithName("B3"));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1));
+ Input(in_.opts().WithName("B3").WithControlInput(a1));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
+ _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
+ ExpectMatchA();
+
+ Node* recv =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
+ Node* id = Identity(recv, b_opts_.WithName("A1/_3"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id));
+ Input(b_opts_.WithName("B3").WithControlInput(id));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = Input(in_.opts().WithName("A1"));
+ Node* b1 = Input(in_.opts().WithName("B1"));
+ Cross(a1, b1, in_.opts().WithName("B2"));
+ Input(in_.opts().WithName("B3").WithControlInput(a1));
+
+ Partition(ToGraphDef(), &partitions_);
+ EXPECT_EQ(2, partitions_.size());
+
+ string a = "/job:a/replica:0/task:0/cpu:0";
+ string b = "/job:a/replica:0/task:0/cpu:1";
+ a1 = Input(a_opts_.WithName("A1"));
+ Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1));
+ // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could
+ // use A1/_0 -> A1/_4 as the control as a minor optimization.
+ _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1"));
+ _Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4"));
+ ExpectMatchA();
+
+ Node* recv1 =
+ _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2"));
+ Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3"));
+ Node* recv2 =
+ _Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5"));
+ b1 = Input(b_opts_.WithName("B1"));
+ Cross(recv2, b1, b_opts_.WithName("B2"));
+ Input(b_opts_.WithName("B3").WithControlInput(id1));
+ ExpectMatchB();
+}
+
+TEST_F(GraphPartitionTest, CrossDeviceLoop) {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ Node* a1 = BoolInput(in_.opts().WithName("A1"));
+ Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2"));
+ Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3"));
+ LoopCond(a3, in_.opts().WithName("A4"));
+ Node* b1 = Identity(a3, in_.opts().WithName("B1"));
+ NextIteration(b1, in_.opts().WithName("A5"));
+
+ CheckLoopConstruction(ToGraphDef());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
new file mode 100644
index 0000000000..f7a8ffde89
--- /dev/null
+++ b/tensorflow/core/graph/graph_test.cc
@@ -0,0 +1,252 @@
+#include "tensorflow/core/graph/graph.h"
+
+#include <set>
+#include <gtest/gtest.h>
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class GraphTest : public ::testing::Test {
+ protected:
+ GraphTest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); }
+ ~GraphTest() override {}
+
+ static void VerifyNodes(Node* node, std::vector<Node*> expected_in,
+ std::vector<Node*> expected_out) {
+ std::vector<Node*> in;
+ for (const Edge* e : node->in_edges()) {
+ in.push_back(e->src());
+ }
+ EXPECT_EQ(Stringify(expected_in), Stringify(in));
+
+ std::vector<Node*> out;
+ for (const Edge* e : node->out_edges()) {
+ out.push_back(e->dst());
+ }
+ EXPECT_EQ(Stringify(expected_out), Stringify(out));
+ }
+
+ Node* AddNodeWithName(const string& name) {
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(name, "NoOp").Finalize(&graph_, &node));
+ return node;
+ }
+
+ Graph graph_;
+
+ private:
+ // Convert a list of nodes to a sorted list of strings so failure messages
+ // are readable.
+ static std::vector<string> Stringify(const std::vector<Node*>& nodes) {
+ std::vector<string> result;
+ for (Node* n : nodes) {
+ result.push_back(n->DebugString());
+ }
+ std::sort(result.begin(), result.end());
+ return result;
+ }
+};
+
+TEST_F(GraphTest, Constructor) {
+ Node* source = graph_.source_node();
+ EXPECT_NE(source, nullptr);
+ Node* sink = graph_.sink_node();
+ EXPECT_NE(sink, nullptr);
+ VerifyNodes(source, {}, {sink});
+ VerifyNodes(sink, {source}, {});
+ EXPECT_EQ(2, graph_.num_node_ids());
+}
+
+TEST_F(GraphTest, RemoveThenAdd) {
+ AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ const int b_id = b->id();
+ AddNodeWithName("C");
+ EXPECT_EQ(5, graph_.num_node_ids());
+ graph_.RemoveNode(b);
+ EXPECT_EQ(5, graph_.num_node_ids());
+ Node* d = AddNodeWithName("D");
+ EXPECT_NE(b_id, d->id()); // Ids should not be reused.
+ EXPECT_EQ(6, graph_.num_node_ids());
+}
+
+TEST_F(GraphTest, InNodesAndOutNodes) {
+ Node* a = AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ Node* c = AddNodeWithName("C");
+ graph_.RemoveNode(b);
+ Node* d = AddNodeWithName("D");
+
+ const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a);
+ graph_.AddControlEdge(a, graph_.sink_node());
+ graph_.AddEdge(a, 0, c, 0);
+ graph_.AddControlEdge(c, graph_.sink_node());
+
+ EXPECT_EQ("A", a->name());
+ VerifyNodes(a, {graph_.source_node()}, {c, graph_.sink_node()});
+
+ EXPECT_EQ("C", c->name());
+ VerifyNodes(c, {a}, {graph_.sink_node()});
+
+ EXPECT_EQ("D", d->name());
+ VerifyNodes(d, {}, {});
+
+ VerifyNodes(graph_.source_node(), {}, {a, graph_.sink_node()});
+ VerifyNodes(graph_.sink_node(), {a, c, graph_.source_node()}, {});
+
+ graph_.RemoveEdge(source_to_a);
+ VerifyNodes(a, {}, {c, graph_.sink_node()});
+ VerifyNodes(graph_.source_node(), {}, {graph_.sink_node()}); // no more a
+
+ graph_.RemoveNode(c);
+ VerifyNodes(a, {}, {graph_.sink_node()}); // no more c
+ VerifyNodes(graph_.sink_node(), {a, graph_.source_node()}, {}); // no more c
+ EXPECT_EQ(6, graph_.num_node_ids());
+ EXPECT_EQ(5, graph_.num_edge_ids());
+}
+
+TEST_F(GraphTest, NodeIteration) {
+ // Set up the graph with some holes due to removals.
+ Node* a = AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ Node* c = AddNodeWithName("C");
+ graph_.RemoveNode(b);
+ Node* d = AddNodeWithName("D");
+ const Edge* source_to_a = graph_.AddControlEdge(graph_.source_node(), a);
+ graph_.AddControlEdge(a, graph_.sink_node());
+ graph_.AddEdge(a, 0, c, 0);
+ graph_.AddControlEdge(c, graph_.sink_node());
+ graph_.RemoveEdge(source_to_a);
+ graph_.RemoveNode(c);
+
+ // expected = set of all node DebugStrings we expect in the graph
+ std::set<string> expected;
+ expected.insert(graph_.source_node()->DebugString());
+ expected.insert(a->DebugString());
+ expected.insert(d->DebugString());
+ expected.insert(graph_.sink_node()->DebugString());
+
+ // Verify that iterating through ids gets the same set of nodes.
+ std::set<string> actual;
+ for (int id = 0; id < graph_.num_node_ids(); ++id) {
+ Node* node = graph_.FindNodeId(id);
+ if (node != nullptr) {
+ actual.insert(node->DebugString());
+ }
+ }
+ EXPECT_EQ(expected, actual);
+
+ // Verify that range-based for loop gets the same set of nodes.
+ actual.clear();
+ for (Node* node : graph_.nodes()) {
+ actual.insert(node->DebugString());
+ }
+ EXPECT_EQ(expected, actual);
+}
+
+static void CheckType(Node* node, bool b) {
+ EXPECT_TRUE(b) << node->DebugString();
+ // Make sure none of the other IsFoo() methods return true.
+ int count = 0;
+ if (node->IsSource()) count++;
+ if (node->IsSink()) count++;
+ if (node->IsOp()) count++;
+ EXPECT_EQ(1, count) << node->DebugString();
+}
+
+TEST_F(GraphTest, Type) {
+ Node* op = AddNodeWithName("A");
+ CheckType(graph_.source_node(), graph_.source_node()->IsSource());
+ CheckType(graph_.sink_node(), graph_.sink_node()->IsSink());
+ CheckType(op, op->IsOp());
+}
+
+// Convert edge iteration results into a sorted string.
+static string EdgeIter(const Graph& g) {
+ std::vector<std::pair<int, int> > edges;
+ for (const Edge* e : g.edges()) {
+ edges.push_back(std::make_pair(e->src()->id(), e->dst()->id()));
+ }
+ std::sort(edges.begin(), edges.end());
+ string result;
+ for (auto& p : edges) {
+ strings::StrAppend(&result, p.first, "->", p.second, ";");
+ }
+ return result;
+}
+
+TEST_F(GraphTest, EdgeIteration) {
+ EXPECT_EQ("0->1;", EdgeIter(graph_));
+
+ Node* a = AddNodeWithName("A");
+ Node* b = AddNodeWithName("B");
+ EXPECT_EQ("0->1;", EdgeIter(graph_)); // Since a,b are currently disconnected
+
+ graph_.AddEdge(a, 0, b, 0);
+ EXPECT_EQ("0->1;2->3;", EdgeIter(graph_));
+
+ graph_.AddControlEdge(graph_.source_node(), a);
+ graph_.AddControlEdge(b, graph_.sink_node());
+ EXPECT_EQ("0->1;0->2;2->3;3->1;", EdgeIter(graph_));
+
+ graph_.AddEdge(a, 1, a, 0);
+ EXPECT_EQ("0->1;0->2;2->2;2->3;3->1;", EdgeIter(graph_));
+}
+
+TEST_F(GraphTest, NewName) {
+ string a1 = graph_.NewName("A");
+ string a2 = graph_.NewName("A");
+ string b1 = graph_.NewName("B");
+ EXPECT_NE(a1, a2);
+ EXPECT_NE(a1, b1);
+ EXPECT_NE(a2, b1);
+ EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1;
+}
+
+REGISTER_OP("Input").Output("o: float");
+REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float");
+
+static void BM_InEdgeIteration(int iters, int num_nodes) {
+ testing::StopTiming();
+ string s;
+ for (int in = 0; in < 10; in++) {
+ s += strings::Printf("node { name: 'in%04d' op: 'Input' }", in);
+ }
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int op = 0; op < num_nodes; op++) {
+ s += strings::Printf(
+ "node { name: 'op%04d' op: 'In2Out1' input: ['in%04d', 'in%04d' ] }",
+ op, rnd.Uniform(10), rnd.Uniform(10));
+ }
+
+ Graph graph(OpRegistry::Global());
+ GraphDef graph_def;
+ CHECK(protobuf::TextFormat::ParseFromString(s, &graph_def));
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
+
+ int64 sum = 0;
+ testing::StartTiming();
+ for (int i = 0; i < iters; i += graph.num_node_ids()) {
+ for (const Node* node : graph.nodes()) {
+ for (auto e : node->in_edges()) {
+ sum += e->id();
+ }
+ }
+ }
+ VLOG(1) << sum;
+}
+BENCHMARK(BM_InEdgeIteration)->Range(10, 100000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
new file mode 100644
index 0000000000..8c34323dbe
--- /dev/null
+++ b/tensorflow/core/graph/node_builder.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/core/graph/node_builder.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+NodeBuilder::NodeBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry)
+ : def_builder_(name, op_name, op_registry) {}
+
+NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def)
+ : def_builder_(name, op_def) {}
+
+NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
+ inputs_.emplace_back(src_node, src_index);
+ DataType dt;
+ if (GetOutputType(src_node, src_index, &dt)) {
+ def_builder_.Input(src_node->name(), src_index, dt);
+ }
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::Input(NodeOut src) {
+ if (src.error) {
+ AddIndexError(src.node, src.index);
+ } else {
+ inputs_.emplace_back(src.node, src.index);
+ def_builder_.Input(src.name, src.index, src.dt);
+ }
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
+ std::vector<NodeDefBuilder::NodeOut> srcs;
+ srcs.reserve(src_list.size());
+ for (const auto& node_out : src_list) {
+ if (node_out.error) {
+ AddIndexError(node_out.node, node_out.index);
+ } else {
+ srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
+ inputs_.emplace_back(node_out.node, node_out.index);
+ }
+ }
+ def_builder_.Input(srcs);
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
+ control_inputs_.emplace_back(src_node);
+ def_builder_.ControlInput(src_node->name());
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
+ control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
+ src_nodes.end());
+ for (Node* src_node : src_nodes) {
+ def_builder_.ControlInput(src_node->name());
+ }
+ return *this;
+}
+
+NodeBuilder& NodeBuilder::Device(const string& device_spec) {
+ def_builder_.Device(device_spec);
+ return *this;
+}
+
+Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
+ // In case of error, set *created_node to nullptr.
+ if (created_node != nullptr) *created_node = nullptr;
+ if (!errors_.empty()) {
+ return errors::InvalidArgument(str_util::Join(errors_, "\n"));
+ }
+
+ NodeDef node_def;
+ TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def));
+ TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
+ Status status;
+ Node* node = graph->AddNode(node_def, &status);
+ if (!status.ok()) return status;
+
+ for (size_t i = 0; i < inputs_.size(); ++i) {
+ if (inputs_[i].node != nullptr) { // Skip back edges.
+ graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
+ }
+ }
+ for (Node* control_input : control_inputs_) {
+ graph->AddControlEdge(control_input, node);
+ }
+ if (created_node != nullptr) *created_node = node;
+ return Status::OK();
+}
+
+void NodeBuilder::AddIndexError(Node* node, int i) {
+ if (node == nullptr) {
+ errors_.emplace_back(
+ strings::StrCat("Attempt to add nullptr Node to node with type",
+ def_builder_.op_def().name()));
+ } else {
+ errors_.emplace_back(
+ strings::StrCat("Attempt to add output ", i, " of ", node->name(),
+ " not in range [0, ", node->num_outputs(),
+ ") to node with type ", def_builder_.op_def().name()));
+ }
+}
+
+bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) {
+ bool error;
+ *dt = SafeGetOutput(node, i, &error);
+ if (error) AddIndexError(node, i);
+ return !error;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
new file mode 100644
index 0000000000..dd34b97f23
--- /dev/null
+++ b/tensorflow/core/graph/node_builder.h
@@ -0,0 +1,146 @@
+#ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_
+#define TENSORFLOW_GRAPH_NODE_BUILDER_H_
+
+#include <vector>
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// This is a helper for creating a Node and adding it to a Graph.
+// Internally, it uses a NodeDefBuilder to automatically set attrs
+// that can be inferred from the inputs, and use default values
+// (where they exist) for unspecified attrs. Example usage:
+//
+// Node* node;
+// Status status = NodeBuilder(node_name, op_name)
+// .Input(...)
+// .Attr(...)
+// .Finalize(&graph, &node);
+// if (!status.ok()) return status;
+// // Use node here.
+class NodeBuilder {
+ public:
+ // For specifying the output of a Node to provide to one of the Input()
+ // functions below. It supports both regular inputs (where you are
+ // connecting to an existing Node*) and inputs from outside the graph
+ // (or from nodes that haven't been added to the graph yet, like back
+ // edges), where you don't have a Node*. Both types can be mixed, e.g. in an
+ // ArraySlice.
+ struct NodeOut {
+ // For referencing an existing Node.
+ NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit)
+ : node(n),
+ error(false),
+ name(node != nullptr ? node->name() : (error = true, "")),
+ index(i),
+ dt(SafeGetOutput(node, i, &error)) {}
+
+ // For referencing Nodes not in the graph being built. It is
+ // useful when preparing a graph for ExtendSession or creating a
+ // back edge to a node that hasn't been added to the graph yet,
+ // but will be.
+ NodeOut(const string& name, int i, DataType t)
+ : node(nullptr), error(false), name(name), index(i), dt(t) {}
+
+ // Default constructor for std::vector<NodeOut>.
+ NodeOut() {}
+
+ Node* node = nullptr;
+ // error is set to true if:
+ // * the NodeOut was default constructed and never overwritten,
+ // * a nullptr Node* was passed to the NodeOut constructor, or
+ // * an out-of-range index was passed to the NodeOut constructor.
+ bool error = true;
+ string name;
+ int index = 0;
+ DataType dt = DT_FLOAT;
+ };
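+
+ // As an example, a back edge into a node that will later be added under
+ // the (purely illustrative) name "next_iter", producing a float at
+ // output 0, can be written before that node exists:
+ //
+ //   NodeBuilder::NodeOut back_edge("next_iter", 0, DT_FLOAT);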
+
+ // Specify the name and the Op (either via an OpDef or the name of
+ // the Op plus a registry) for the Node. Other fields are
+ // specified by calling the methods below.
+ // REQUIRES: The OpDef must satisfy ValidateOpDef().
+ NodeBuilder(const string& name, const string& op_name,
+ const OpRegistryInterface* op_registry = OpRegistry::Global());
+ NodeBuilder(const string& name, const OpDef* op_def);
+
+ // You must call one Input() function per input_arg in the Op,
+ // *and in the same order as the input_args appear in the OpDef.*
+
+ // For inputs that take a single tensor.
+ NodeBuilder& Input(Node* src_node, int src_index = 0);
+ NodeBuilder& Input(NodeOut src);
+
+ // For inputs that take a list of tensors.
+ NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
+
+ // Require that this node run after src_node(s).
+ NodeBuilder& ControlInput(Node* src_node);
+ NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes);
+
+ // Sets the "requested device spec" in the NodeDef (not the
+ // "assigned device" in the Node).
+ NodeBuilder& Device(const string& device_spec);
+
+ // Set the value of an attr. attr_name must match the name of one of
+ // attrs defined by the Op, and value must have the corresponding type
+ // (see SetAttrValue() in ../framework/attr_value_util.h for legal
+ // types for value). Note that attrs will be set automatically if
+ // they can be determined by the inputs.
+ template <class T>
+ NodeBuilder& Attr(const string& attr_name, T&& value);
+ template <class T>
+ NodeBuilder& Attr(const string& attr_name, std::initializer_list<T> value);
+
+ // Validates the described node and adds it to *graph, adding edges
+ // for all (non-back) inputs. If created_node is not nullptr,
+ // *created_node will be set to the new node (or nullptr on error).
+ Status Finalize(Graph* graph, Node** created_node) const;
+
+ private:
+ static DataType SafeGetOutput(Node* node, int i, bool* error) {
+ if (node != nullptr && i >= 0 && i < node->num_outputs()) {
+ *error = false;
+ return node->output_type(i);
+ } else {
+ *error = true;
+ return DT_FLOAT;
+ }
+ }
+
+ // If SafeGetOutput indicates a range error, add it to errors_.
+ void AddIndexError(Node* node, int i);
+
+ // Sets *dt and returns true if i is in range. Combines
+ // SafeGetOutput() and AddIndexError().
+ bool GetOutputType(Node* node, int i, DataType* dt);
+
+ NodeDefBuilder def_builder_;
+ std::vector<NodeOut> inputs_;
+ std::vector<Node*> control_inputs_;
+ std::vector<string> errors_;
+};
+
+// IMPLEMENTATION -------------------------------------------------------------
+
+template <class T>
+inline NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) {
+ def_builder_.Attr(attr_name, std::forward<T>(value));
+ return *this;
+}
+
+template <class T>
+NodeBuilder& NodeBuilder::Attr(const string& attr_name,
+ std::initializer_list<T> value) {
+ def_builder_.Attr(attr_name, value);
+ return *this;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_NODE_BUILDER_H_
diff --git a/tensorflow/core/graph/node_builder_test.cc b/tensorflow/core/graph/node_builder_test.cc
new file mode 100644
index 0000000000..9f667d00e4
--- /dev/null
+++ b/tensorflow/core/graph/node_builder_test.cc
@@ -0,0 +1,59 @@
+#include "tensorflow/core/graph/node_builder.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("Source").Output("o: out_types").Attr("out_types: list(type)");
+REGISTER_OP("Sink").Input("i: T").Attr("T: type");
+
+TEST(NodeBuilderTest, Simple) {
+ RequireDefaultOps();
+ Graph graph(OpRegistry::Global());
+ Node* source_node;
+ EXPECT_OK(NodeBuilder("source_op", "Source")
+ .Attr("out_types", {DT_INT32, DT_STRING})
+ .Finalize(&graph, &source_node));
+ ASSERT_TRUE(source_node != nullptr);
+
+ // Try connecting to each of source_node's outputs.
+ EXPECT_OK(NodeBuilder("sink1", "Sink")
+ .Input(source_node)
+ .Finalize(&graph, nullptr));
+ EXPECT_OK(NodeBuilder("sink2", "Sink")
+ .Input(source_node, 1)
+ .Finalize(&graph, nullptr));
+
+ // Generate an error if the index is out of range.
+ EXPECT_FALSE(NodeBuilder("sink3", "Sink")
+ .Input(source_node, 2)
+ .Finalize(&graph, nullptr)
+ .ok());
+ EXPECT_FALSE(NodeBuilder("sink4", "Sink")
+ .Input(source_node, -1)
+ .Finalize(&graph, nullptr)
+ .ok());
+ EXPECT_FALSE(NodeBuilder("sink5", "Sink")
+ .Input({source_node, -1})
+ .Finalize(&graph, nullptr)
+ .ok());
+
+ // Generate an error if the node is nullptr. This can happen when using
+ // GraphDefBuilder if there was an error creating the input node.
+ EXPECT_FALSE(NodeBuilder("sink6", "Sink")
+ .Input(nullptr)
+ .Finalize(&graph, nullptr)
+ .ok());
+ EXPECT_FALSE(NodeBuilder("sink7", "Sink")
+ .Input(NodeBuilder::NodeOut(nullptr, 0))
+ .Finalize(&graph, nullptr)
+ .ok());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
new file mode 100644
index 0000000000..2fa6f075c0
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -0,0 +1,220 @@
+// This module implements a common subexpression elimination pass. We
+// process the nodes in the graph in reverse postorder
+// (i.e. inputs before the nodes that consume them). The rough algorithm is
+// as follows:
+//
+// std::unordered_map<size_t, Node*> available
+// for each node n in forward topological order:
+// h = NodeHash(n)
+// if available[h] exists and Equivalent(available[h], n)
+// redirect downstream uses of outputs of n to available[h]
+// remove n from graph
+// else
+// if available[h] does not exist
+// available[h] = n
+//
+// This is similar to the global value numbering algorithm described in this
+// paper:
+//
+// "Global code motion/global value numbering", Cliff Click, PLDI '95
+// Proceedings of the ACM SIGPLAN 1995 conference on Programming
+// language design and implementation, Pages 246-257
+// http://dl.acm.org/citation.cfm?id=207154
+
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class OptimizerCSE {
+ public:
+ explicit OptimizerCSE(Graph* g) : g_(g) {}
+
+ void Optimize(std::function<bool(const Node*)> consider_fn);
+
+ private:
+ struct Scratch;
+
+ static size_t NodeHash(const Node* n);
+ static bool Equivalent(const Node* a, const Node* b, Scratch* s);
+ static bool EqualAttrs(const Node* a, const Node* b, Scratch* s);
+
+ Graph* g_;
+};
+
+static void FillInputs(const Node* n,
+ gtl::InlinedVector<Node*, 4>* control_edges,
+ gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
+ DCHECK_EQ(in->size(), n->num_inputs());
+ control_edges->clear();
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges->push_back(e->src());
+ } else {
+ (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
+ }
+ }
+ std::sort(control_edges->begin(), control_edges->end());
+ if (n->op_def().is_commutative()) {
+ // For commutative inputs, we sort the input by the input Node*
+ // to get a canonical ordering (so that add(a,b) and add(b, a) will
+ // hash to the same value if is_commutative is true for 'add').
+ std::sort(in->begin(), in->end());
+ }
+}
+
+static size_t kIllegalNodeHash = 0;
+
+size_t OptimizerCSE::NodeHash(const Node* n) {
+ const DataTypeVector& out = n->output_types();
+ string str_to_hash = strings::StrCat(n->type_string(), out.size());
+ for (DataType dt : out) {
+ strings::StrAppend(&str_to_hash, dt);
+ }
+
+ const int N_in = n->num_inputs();
+ strings::StrAppend(&str_to_hash, N_in);
+ gtl::InlinedVector<Node*, 4> control_edges;
+ gtl::InlinedVector<std::pair<Node*, int>, 4> in(N_in);
+ FillInputs(n, &control_edges, &in);
+ for (const auto& edge : in) {
+ strings::StrAppend(&str_to_hash, edge.first->id(), edge.second);
+ }
+
+ size_t h = Hash64(str_to_hash);
+
+#if !defined(__ANDROID__) && !defined(ANDROID)
+ // Hash the attrs. For example, this makes sure different constants
+ // end up in different hash buckets.
+ string tmp;
+ for (const auto& attr : n->def().attr()) {
+ tmp = attr.first;
+ attr.second.AppendToString(&tmp);
+ // Add hashes of attrs, so the order of attrs doesn't matter.
+ h += Hash32(tmp.data(), tmp.size(), 0x87341245);
+ }
+#endif
+
+ if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
+ return h;
+}
+
+struct OptimizerCSE::Scratch {
+ // For EqualAttrs():
+ string a;
+ string b;
+};
+
+bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) {
+ if (a->def().attr_size() != b->def().attr_size()) return false;
+
+ for (const auto& attr : b->def().attr()) {
+ auto iter = a->def().attr().find(attr.first);
+ if (iter == a->def().attr().end()) return false;
+ // Note: it should be safe to compare proto serializations of the attr
+ // values since at most one field should be set in each (indeed, it
+ // should be the same field).
+ iter->second.SerializeToString(&scratch->a);
+ attr.second.SerializeToString(&scratch->b);
+ if (scratch->a != scratch->b) return false;
+ }
+ return true;
+}
+
+static bool HasRefInput(const Node* n) {
+ for (auto dt : n->input_types()) {
+ if (IsRefType(dt)) return true;
+ }
+ return false;
+}
+
+bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
+ // Nodes with different op names are never equivalent.
+ if (a->type_string() != b->type_string()) return false;
+
+ // Never consider stateful nodes (such as non-const inputs) equivalent.
+ if (a->op_def().is_stateful()) return false;
+
+ // For now, we consider any node that takes a ref input to not be
+ // equivalent to any other node.
+ if (HasRefInput(a) || HasRefInput(b)) return false;
+
+ // Compare attrs. Note that equal attrs implies equal input and
+ // output types.
+ if (!EqualAttrs(a, b, scratch)) return false;
+
+ // Compare input sources
+ if (a->num_inputs() != b->num_inputs()) return false;
+ const int N_in = a->num_inputs();
+ gtl::InlinedVector<Node*, 4> a_control_edges;
+ gtl::InlinedVector<Node*, 4> b_control_edges;
+ gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
+ gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
+ FillInputs(a, &a_control_edges, &a_in);
+ FillInputs(b, &b_control_edges, &b_in);
+ if (a_in != b_in) return false;
+ if (a_control_edges != b_control_edges) return false;
+
+ return true;
+}
+
+void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
+ // This very simple implementation works if the whole graph is one
+ // giant basic block (because we just traverse nodes in a
+ // topological order). We'll need to do something more
+ // sophisticated when we have control flow/loops/etc.
+
+ // TODO(jeff): We need to handle Update nodes specially, but dealing
+ // with more general control flow will also solve this issue, and for
+ // now, our updates are almost always the most downstream nodes in
+ // the graph.
+ std::vector<Node*> order;
+ GetReversePostOrder(*g_, &order);
+
+ // Our value is just a single Node*, meaning we keep just a single
+ // candidate for a given node hash value. This may cause us to
+ // (rarely) lose some optimization opportunities if there are
+ // hash collisions, but it allows us to avoid having the value
+ // be a set<Node*> (or equivalent).
+ std::unordered_map<size_t, Node*> available;
+
+ // Scratch space for Equivalent calls. Allocated here and passed in to
+ // Equivalent to avoid allocation inside the loop below.
+ Scratch scratch;
+ for (Node* n : order) {
+ if (!n->IsOp()) continue;
+
+ // See if we should consider this node at all
+ if (consider_fn != nullptr && !consider_fn(n)) continue;
+
+ size_t h = NodeHash(n);
+ Node** candidate = &available[h];
+ if (*candidate == nullptr) {
+ // No existing match: insert "n" into the hash table under "h"
+ *candidate = n;
+ } else if (Equivalent(*candidate, n, &scratch)) {
+ VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
+ << n->name();
+ // *candidate and n are equivalent. Therefore, we can replace
+ // n with *candidate by fixing up outgoing edges from "n" to instead
+ // come from "*candidate", and then delete n from the graph
+ for (const Edge* e : n->out_edges()) {
+ g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
+ }
+ g_->RemoveNode(n);
+ }
+ }
+}
+
+void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
+ OptimizerCSE opt(g);
+ opt.Optimize(consider_fn);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
new file mode 100644
index 0000000000..430c97a449
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -0,0 +1,19 @@
+// An optimization pass that performs common subexpression elimination.
+
+#ifndef TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+#define TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+
+#include <sys/types.h>
+#include <functional>
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Perform common-subexpression elimination on the graph "*g". If
+// "consider_fn" is not nullptr, then only nodes for which
+// consider_fn(node) returns true will be considered for combining
+// during the common subexpression elimination.
+extern void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn);
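+//
+// A minimal usage sketch (editor's note, not part of the original change; the
+// graph-construction step is elided):
+//
+//   Graph g(OpRegistry::Global());
+//   // ... build g ...
+//   // Combine duplicate nodes, but leave "Const" nodes untouched.
+//   OptimizeCSE(&g, [](const Node* n) { return n->type_string() != "Const"; });
+//
+// Passing nullptr for "consider_fn" makes every node a candidate.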
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc
new file mode 100644
index 0000000000..ebbb948fdc
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse_test.cc
@@ -0,0 +1,365 @@
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+static void InitGraph(const string& s, Graph* graph) {
+ GraphDef graph_def;
+
+ auto parser = protobuf::TextFormat::Parser();
+ // parser.AllowRelaxedWhitespace(true);
+ CHECK(parser.MergeFromString(s, &graph_def)) << s;
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
+}
+
+class OptimizerCSETest : public ::testing::Test {
+ public:
+ OptimizerCSETest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); }
+
+ void InitGraph(const string& s) {
+ ::tensorflow::InitGraph(s, &graph_);
+ original_ = CanonicalGraphString(&graph_);
+ }
+
+ static bool IncludeNode(const Node* n) { return n->IsOp(); }
+
+ static string EdgeId(const Node* n, int index) {
+ if (index == 0) {
+ return n->name();
+ } else if (index == Graph::kControlSlot) {
+ return strings::StrCat(n->name(), ":control");
+ } else {
+ return strings::StrCat(n->name(), ":", index);
+ }
+ }
+
+ string CanonicalGraphString(Graph* g) {
+ std::vector<string> nodes;
+ std::vector<string> edges;
+ for (const Node* n : g->nodes()) {
+ if (IncludeNode(n)) {
+ nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
+ }
+ }
+ for (const Edge* e : g->edges()) {
+ if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
+ edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
+ EdgeId(e->dst(), e->dst_input())));
+ }
+ }
+ // Canonicalize
+ std::sort(nodes.begin(), nodes.end());
+ std::sort(edges.begin(), edges.end());
+ return strings::StrCat(str_util::Join(nodes, ";"), "|",
+ str_util::Join(edges, ";"));
+ }
+
+ string DoCSE(std::function<bool(const Node*)> consider_fn = nullptr) {
+ string before = CanonicalGraphString(&graph_);
+ LOG(ERROR) << "Before rewrites: " << before;
+
+ OptimizeCSE(&graph_, consider_fn);
+
+ string result = CanonicalGraphString(&graph_);
+ LOG(ERROR) << "After rewrites: " << result;
+ return result;
+ }
+
+ const string& OriginalGraph() const { return original_; }
+
+ Graph graph_;
+ string original_;
+};
+
+REGISTER_OP("Input").Output("o: float").SetIsStateful();
+
+// Note that the "rules" in these tests are not meant to be logically correct
+TEST_F(OptimizerCSETest, Simple) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);E(Mul)|"
+ "A->E;B->E:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_WithFixups) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul);E(Mul)|"
+ "A->D;B->D:1;D->E;D->E:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_Commutative) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'A'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D:1;B->D");
+}
+
+static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; }
+
+// Like Simple_Commutative, but with a filter that excludes Mul nodes from
+// consideration, so no CSE is performed.
+TEST_F(OptimizerCSETest, Simple_Filtered) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'A'] }");
+ EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, Simple_NotCommutative) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'A'] }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, NotEquivalent_Ops) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) {
+ // Should still do CSE for ops with attrs if they match.
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] attr { key: 'shape'"
+ " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] attr { key: 'shape'"
+ " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
+ // Should still do CSE for ops with attrs if they match, even if they
+ // are not in the same order.
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 'a' value { i: 3 } }"
+ " attr { key: 't' value { type: DT_INT32 } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 't' value { type: DT_INT32 } }"
+ " attr { key: 'a' value { i: 3 } } }");
+ EXPECT_EQ(DoCSE(),
+ "A(Input);B(Input);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, SameConstants) {
+ // Should still do CSE for ops with constants if the values are identical
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "B(Const);D(Mul)|"
+ "B->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, DifferentConstants) {
+ // Should NOT do CSE for constants whose values differ.
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value {"
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 100000 } } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoCSE(),
+ "A(Const);B(Const);D(Mul)|"
+ "A->D;B->D:1");
+}
+
+TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 'a' value { i: 3 } }"
+ " attr { key: 't' value { type: DT_INT32 } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 't' value { type: DT_INT32 } }"
+ " attr { key: 'a' value { i: 4 } } }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 'a' value { i: 3 } }"
+ " attr { key: 't' value { type: DT_FLOAT } } }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B']"
+ " attr { key: 't' value { type: DT_INT32 } }"
+ " attr { key: 'a' value { i: 3 } } }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, NotEquivalent_Inputs) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'C'] }");
+ EXPECT_EQ(DoCSE(), OriginalGraph());
+}
+
+TEST_F(OptimizerCSETest, Constant_Dedup) {
+ Tensor a(DT_FLOAT, TensorShape({1}));
+ a.flat<float>()(0) = 1.0;
+ Tensor b(DT_DOUBLE, TensorShape({1})); // Different type
+ b.flat<double>()(0) = 1.0;
+ Tensor c(DT_FLOAT, TensorShape({1, 1})); // Different shape
+ c.flat<float>()(0) = 1.0;
+ Tensor d(DT_FLOAT, TensorShape({1})); // Different value
+ d.flat<float>()(0) = 2.0;
+
+ // A graph contains a bunch of constants.
+ Graph g(OpRegistry::Global());
+ for (auto val : {a, b, c, d, d, c, b, a}) {
+ test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ...
+ }
+ GraphDef gdef;
+ test::graph::ToGraphDef(&g, &gdef);
+ InitGraph(gdef.DebugString());
+
+ EXPECT_EQ(OriginalGraph(),
+ "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);"
+ "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
+ // In theory, there are 2^4 possible correct outputs of CSE. In this
+ // test, it happens to eliminate the first 4 nodes.
+ EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
+}
+
+static void BM_CSE(int iters, int op_nodes) {
+ testing::StopTiming();
+ string s;
+ for (int in = 0; in < 10; in++) {
+ s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
+ }
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int op = 0; op < op_nodes; op++) {
+ s += strings::Printf(
+ "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
+ "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
+ op, rnd.Uniform(10), rnd.Uniform(10));
+ }
+
+ bool first = true;
+ while (iters > 0) {
+ Graph* graph = new Graph(OpRegistry::Global());
+ InitGraph(s, graph);
+ int N = graph->num_node_ids();
+ if (first) {
+ testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
+ first = false;
+ }
+ {
+ testing::StartTiming();
+ OptimizeCSE(graph, nullptr);
+ testing::StopTiming();
+ }
+ iters -= N; // Our benchmark units are individual graph nodes,
+ // not whole graphs
+ delete graph;
+ }
+}
+BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
new file mode 100644
index 0000000000..7910511dfb
--- /dev/null
+++ b/tensorflow/core/graph/subgraph.cc
@@ -0,0 +1,258 @@
+#include "tensorflow/core/graph/subgraph.h"
+
+#include <algorithm>
+#include <deque>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// ----------------------------------------------------------------------------
+// Subgraph construction-related routines
+// ----------------------------------------------------------------------------
+// TODO(vrv): Profile the unordered_set and unordered_map use in this file to
+// see if we should use an alternative implementation.
+
+namespace {
+
+typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
+
+// Rewrites the graph by replacing each output tensor specified in
+// "fed_outputs" with a special "_Recv" feed node and redirecting that
+// tensor's consumers to read from the feed node instead. The new feed
+// nodes are recorded in "*name_index".
+//
+// Returns OK on success. On error, returns a non-OK status with an
+// appropriate error message (and *g is left in an indeterminate
+// state).
+static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
+ const gtl::ArraySlice<string>& fed_outputs,
+ NameIndex* name_index) {
+ for (const string& t : fed_outputs) {
+ TensorId id(ParseTensorName(t));
+
+ auto iter = name_index->find(id.first);
+ if (iter == name_index->end()) {
+ return errors::NotFound("FeedInputs: unable to find feed output ", t);
+ }
+ const Node* n = iter->second;
+ DCHECK_EQ(n->name(), id.first);
+ if (id.second >= n->num_outputs()) {
+ return errors::InvalidArgument(
+ "FeedInputs: ", t, " should have output index < ", n->num_outputs());
+ }
+
+ Node* recv_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
+ "_Recv")
+ .Attr("tensor_type", BaseType(n->output_type(id.second)))
+ .Attr("tensor_name", t)
+ .Attr("send_device", device_info.name())
+ .Attr("recv_device", device_info.name())
+ .Attr("send_device_incarnation",
+ static_cast<int64>(device_info.incarnation()))
+ .Attr("client_terminated", true)
+ .Finalize(g, &recv_node));
+ recv_node->set_assigned_device_name(device_info.name());
+
+ // Update name_index
+ (*name_index)[recv_node->name()] = recv_node;
+ g->AddControlEdge(g->source_node(), recv_node);
+
+ // Look through edges coming out of "n" for edges whose src_output() index
+ // matches "output_index". If found, replace the edges with a connection
+ // from the special feed node.
+ std::vector<const Edge*> to_remove;
+ for (const Edge* e : n->out_edges()) {
+ if (e->src_output() == id.second) {
+ to_remove.emplace_back(e);
+ } else if (e->src_output() == Graph::kControlSlot &&
+ n->def().op() == "Placeholder") {
+ // When feeding a Placeholder node, any outgoing control edges
+ // will be replaced with a control edge from the replacement
+ // recv_node.
+ // TODO(josh11b,mrry): Come up with a more elegant way of addressing
+ // the general version of this problem.
+ to_remove.emplace_back(e);
+ }
+ }
+
+ for (const Edge* e : to_remove) {
+ if (e->src_output() == id.second) {
+ g->AddEdge(recv_node, 0, e->dst(), e->dst_input());
+ } else {
+ CHECK_EQ(Graph::kControlSlot, e->src_output());
+ g->AddControlEdge(recv_node, e->dst());
+ }
+ g->RemoveEdge(e);
+ }
+ }
+ return Status::OK();
+}
+
+// Augments "*g" by adding special "_Send" fetch nodes that read the
+// tensor outputs specified in "fetch_outputs". The new nodes are set up
+// to execute on the device described by "device_info", are recorded in
+// "*name_index", and are returned in "*fetch_nodes".
+//
+// Returns OK on success. On error, returns a non-OK status with an
+// appropriate error message (and *g is left in an indeterminate
+// state).
+static Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ NameIndex* name_index,
+ std::vector<Node*>* fetch_nodes) {
+ fetch_nodes->clear();
+ for (const string& t : fetch_outputs) {
+ // Parse t into node_name and output_index.
+ TensorId id(ParseTensorName(t));
+
+ // Find node in graph with that name.
+ auto iter = name_index->find(id.first);
+ if (iter == name_index->end()) {
+ return errors::NotFound("FetchOutputs node ", t, ": not found");
+ }
+ Node* n = iter->second;
+ DCHECK_EQ(n->name(), id.first);
+ VLOG(2) << "Found fetch node for " << t;
+
+ // Validate output_index
+ if (id.second >= n->num_outputs()) {
+ return errors::InvalidArgument("FetchOutputs ", t,
+ ": output index too large, must be < ",
+ n->num_outputs());
+ }
+
+ // Create the fetch Node and connect it up
+ Node* send_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
+ "_Send")
+ .Input(n, id.second)
+ .Attr("tensor_name", t)
+ .Attr("send_device", device_info.name())
+ .Attr("recv_device", device_info.name())
+ .Attr("send_device_incarnation",
+ static_cast<int64>(device_info.incarnation()))
+ .Attr("client_terminated", true)
+ .Finalize(g, &send_node));
+ send_node->set_assigned_device_name(device_info.name());
+ VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());
+
+ // Update the index.
+ (*name_index)[send_node->name()] = send_node;
+
+ g->AddControlEdge(send_node, g->sink_node());
+ fetch_nodes->push_back(send_node);
+ }
+
+ return Status::OK();
+}
+
+static bool AddNodeToTargets(const string& node_or_tensor_name,
+ const NameIndex& name_index,
+ std::unordered_set<const Node*>* targets) {
+ TensorId id = ParseTensorName(node_or_tensor_name);
+ auto iter = name_index.find(id.first);
+ if (iter == name_index.end()) {
+ return false;
+ }
+ const Node* n = iter->second;
+ if (n->name() != node_or_tensor_name) {
+ return false;
+ }
+
+ targets->insert(n);
+ return true;
+}
+
+static Status PruneForTargets(Graph* g, const NameIndex& name_index,
+ const std::vector<Node*>& fetch_nodes,
+ const gtl::ArraySlice<string>& target_nodes) {
+ string not_found;
+ std::unordered_set<const Node*> targets;
+ for (Node* n : fetch_nodes) {
+ if (!AddNodeToTargets(n->name(), name_index, &targets)) {
+ strings::StrAppend(&not_found, n->name(), " ");
+ }
+ }
+ for (const string& s : target_nodes) {
+ if (!AddNodeToTargets(s, name_index, &targets)) {
+ strings::StrAppend(&not_found, s, " ");
+ }
+ }
+ if (!not_found.empty()) {
+ return errors::NotFound("PruneForTargets: Some target nodes not found: ",
+ not_found);
+ }
+ PruneForReverseReachability(g, targets);
+
+ return Status::OK();
+}
+
+} // namespace
+
+namespace subgraph {
+
+Status RewriteGraphForExecution(
+ Graph* g, const gtl::ArraySlice<string>& fed_outputs,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ const gtl::ArraySlice<string>& target_node_names,
+ const DeviceAttributes& device_info) {
+ std::unordered_set<string> endpoints(fed_outputs.begin(), fed_outputs.end());
+ for (const auto& fetch : fetch_outputs) {
+ if (endpoints.count(fetch) > 0) {
+ return errors::InvalidArgument(fetch, " is both fed and fetched.");
+ }
+ }
+
+ // A separate index mapping name to Node*, for use by FeedInputs,
+ // FetchOutputs, and PruneForTargets
+ NameIndex name_index;
+ for (Node* n : g->nodes()) {
+ name_index[n->name()] = n;
+ }
+
+ // Add the feeds. This may replace nodes in the graph, including the nodes
+ // currently listed in "fetch_nodes". We pass "name_index" so the index is
+ // kept up to date.
+ if (!fed_outputs.empty()) {
+ TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
+ }
+
+ // Add the fetch nodes, also updating "name_index".
+ std::vector<Node*> fetch_nodes;
+ if (!fetch_outputs.empty()) {
+ TF_RETURN_IF_ERROR(
+ FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
+ }
+
+ // Prune the graph to only compute what is needed for the fetch nodes and the
+ // targets nodes.
+ if (!fetch_nodes.empty() || !target_node_names.empty()) {
+ TF_RETURN_IF_ERROR(
+ PruneForTargets(g, name_index, fetch_nodes, target_node_names));
+ }
+
+ return Status::OK();
+}
+
+} // namespace subgraph
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
new file mode 100644
index 0000000000..d2e138e8ae
--- /dev/null
+++ b/tensorflow/core/graph/subgraph.h
@@ -0,0 +1,49 @@
+#ifndef TENSORFLOW_GRAPH_SUBGRAPH_H_
+#define TENSORFLOW_GRAPH_SUBGRAPH_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace subgraph {
+
+// Rewrite the graph structure of "*g" to deal with feeding node
+// outputs, fetching node outputs, and only running a subset of the
+// graph. "fed_outputs" and "fetch_outputs" are both lists of
+// output tensor identifiers in the form of
+// "<name>[:<optional_output_index>]", and "target_node_names" is a
+// list of target node names in "*g".
+//
+// In the resulting graph "*g", output edges in "fed_outputs" have
+// been redirected to special "_recv" nodes introduced into the graph.
+// If these fed nodes are not needed in order to compute the effects
+// of the nodes in "target_node_names" and "fetch_outputs", then they may
+// be omitted from the graph.
+//
+// In the resulting graph "*g", additional "_send" nodes are connected
+// to every output in "fetch_outputs". These "_send" nodes are set up
+// to execute on the device described by device_info.
+//
+// On success, returns OK, and sets "*g" to a version of "*g"
+// that represents the portions of the graph necessary for producing
+// the output of all nodes listed in "target_node_names" and fetching the
+// specific node outputs specified in "fetch_outputs".
+//
+// On failure, returns the error status. Possible errors include:
+// - fed output "node:output_index" does not exist in "*g"
+// - fetch output "node:output_index" does not exist in "*g"
+// - target node "node" does not exist in "*g"
+Status RewriteGraphForExecution(
+ Graph* g, const gtl::ArraySlice<string>& fed_outputs,
+ const gtl::ArraySlice<string>& fetch_outputs,
+ const gtl::ArraySlice<string>& target_node_names,
+ const DeviceAttributes& device_info);
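+//
+// A minimal usage sketch (editor's note, not part of the original change; the
+// tensor and node names below are hypothetical):
+//
+//   DeviceAttributes device_info;
+//   device_info.set_name("/job:a/replica:0/task:0/cpu:0");
+//   device_info.set_device_type(DeviceType(DEVICE_CPU).type());
+//   std::vector<string> fed = {"input:0"};     // rewired to "_recv" nodes
+//   std::vector<string> fetch = {"logits:0"};  // "_send" nodes are attached
+//   std::vector<string> targets = {"train_op"};
+//   Status s = subgraph::RewriteGraphForExecution(g, fed, fetch, targets,
+//                                                 device_info);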
+
+} // namespace subgraph
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_SUBGRAPH_H_
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
new file mode 100644
index 0000000000..ffb3e6e403
--- /dev/null
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -0,0 +1,305 @@
+#include "tensorflow/core/graph/subgraph.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/status.h"
+#include <gtest/gtest.h>
+
+// TODO(josh11b): Test setting the "device" field of a NodeDef.
+// TODO(josh11b): Test that feeding won't prune targets.
+
+namespace tensorflow {
+namespace {
+
+class SubgraphTest : public ::testing::Test {
+ protected:
+ SubgraphTest() : g_(new Graph(OpRegistry::Global())) {
+ RequireDefaultOps();
+ device_info_.set_name("/job:a/replica:0/task:0/cpu:0");
+ device_info_.set_device_type(DeviceType(DEVICE_CPU).type());
+ device_info_.set_incarnation(0);
+ }
+
+ ~SubgraphTest() override {}
+
+ void ExpectOK(const string& gdef_ascii) {
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
+ }
+
+ Node* FindNode(const string& name) {
+ for (Node* n : g_->nodes()) {
+ if (n->name() == name) return n;
+ }
+ return nullptr;
+ }
+
+ bool HasNode(const string& name) { return FindNode(name) != nullptr; }
+
+ void ExpectNodes(const string& nodes) {
+ int count = 0;
+ std::vector<string> actual_nodes;
+ for (Node* n : g_->nodes()) {
+ if (n->IsOp()) {
+ count++;
+ actual_nodes.push_back(n->name());
+ }
+ }
+ std::sort(actual_nodes.begin(), actual_nodes.end());
+
+ LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " ");
+
+ std::vector<string> expected_nodes = str_util::Split(nodes, ',');
+ std::sort(expected_nodes.begin(), expected_nodes.end());
+ for (const string& s : expected_nodes) {
+ Node* n = FindNode(s);
+ EXPECT_TRUE(n != nullptr) << s;
+ if (n->def().op() == "_Send" || n->def().op() == "_Recv") {
+ EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
+ }
+ }
+
+ EXPECT_TRUE(actual_nodes.size() == expected_nodes.size())
+ << "\nActual: " << str_util::Join(actual_nodes, ",")
+ << "\nExpected: " << str_util::Join(expected_nodes, ",");
+ }
+
+ bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
+ for (const Edge* e : g_->edges()) {
+ if (e->src()->name() == src && e->src_output() == src_out &&
+ e->dst()->name() == dst && e->dst_input() == dst_in)
+ return true;
+ }
+ return false;
+ }
+ bool HasControlEdge(const string& src, const string& dst) {
+ return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
+ }
+
+ string Subgraph(const string& fed_str, const string& fetch_str,
+ const string& targets_str) {
+ Graph* subgraph = new Graph(OpRegistry::Global());
+ CopyGraph(*g_, subgraph);
+ std::vector<string> fed =
+ str_util::Split(fed_str, ',', str_util::SkipEmpty());
+ std::vector<string> fetch =
+ str_util::Split(fetch_str, ',', str_util::SkipEmpty());
+ std::vector<string> targets =
+ str_util::Split(targets_str, ',', str_util::SkipEmpty());
+
+ Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
+ targets, device_info_);
+ if (!s.ok()) {
+ delete subgraph;
+ return s.ToString();
+ }
+
+ // Replace the graph with the subgraph for the rest of the test.
+ g_.reset(subgraph);
+ return "OK";
+ }
+
+ Graph* graph() { return g_.get(); }
+
+ private:
+ GraphDef gdef_;
+ std::unique_ptr<Graph> g_;
+ DeviceAttributes device_info_;
+};
+
+REGISTER_OP("TestParams").Output("o: float");
+REGISTER_OP("TestInput").Output("a: float").Output("b: float");
+REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
+REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");
+
+TEST_F(SubgraphTest, Targets1) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "", "t1"));
+ ExpectNodes("W1,input,t1");
+}
+
+TEST_F(SubgraphTest, Targets2) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }"
+ "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a"));
+ ExpectNodes("W1,W2,input,t1,t2,t3_a");
+}
+
+TEST_F(SubgraphTest, FedOutputs1) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("input:1", "", "t2"));
+ ExpectNodes("W1,W2,_recv_input_1,t1,t2");
+}
+
+TEST_F(SubgraphTest, FedRefNode) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
+ EXPECT_EQ("OK", Subgraph("W1:0", "", "t1"));
+ ExpectNodes("_recv_W1_0,W2,t1");
+ Node* n = FindNode("_recv_W1_0");
+ EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
+}
+
+TEST_F(SubgraphTest, FedOutputs2) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ // We feed input:1, but nothing connects to it, so the _recv(input:1)
+ // node also disappears.
+ EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
+ ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
+}
+
+TEST_F(SubgraphTest, FetchOutputs1) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2"));
+ ExpectNodes(
+ "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
+}
+
+TEST_F(SubgraphTest, FetchOutputs2) {
+ ExpectOK(
+ "node { name: 'W1' op: 'TestParams' }"
+ "node { name: 'W2' op: 'TestParams' }"
+ "node { name: 'input' op: 'TestInput' }"
+ "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+ "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+ "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+ "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+ EXPECT_EQ("OK", Subgraph("", "t3_a", "t2"));
+ ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
+}
+
+TEST_F(SubgraphTest, ChainOfFools) {
+ ExpectOK(
+ "node { name: 'a' op: 'TestParams' }"
+ "node { name: 'b' op: 'TestRelu' input: 'a'}"
+ "node { name: 'c' op: 'TestRelu' input: 'b'}"
+ "node { name: 'd' op: 'TestRelu' input: 'c'}"
+ "node { name: 'e' op: 'TestRelu' input: 'd'}"
+ "node { name: 'f' op: 'TestRelu' input: 'e'}");
+ EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", ""));
+ ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0");
+ EXPECT_TRUE(HasEdge("a", 0, "b", 0));
+ EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0));
+ EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0));
+ EXPECT_TRUE(HasEdge("d", 0, "e", 0));
+ EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0));
+}
+
+static bool HasSubstr(const string& base, const string& substr) {
+ bool ok = StringPiece(base).contains(substr);
+ EXPECT_TRUE(ok) << base << ", expected substring " << substr;
+ return ok;
+}
+
+TEST_F(SubgraphTest, Errors) {
+ ExpectOK(
+ "node { name: 'a' op: 'TestParams' }"
+ "node { name: 'b' op: 'TestRelu' input: 'a'}"
+ "node { name: 'c' op: 'TestRelu' input: 'b'}"
+ "node { name: 'd' op: 'TestRelu' input: 'c'}"
+ "node { name: 'e' op: 'TestRelu' input: 'd'}"
+ "node { name: 'f' op: 'TestRelu' input: 'e'}");
+ // Duplicated feed and fetch
+ EXPECT_TRUE(
+ HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched"));
+ // Feed not found.
+ EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "", ""), "unable to find"));
+ // Fetch not found.
+ EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found"));
+ // Target not found.
+ EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found"));
+}
+
+REGISTER_OP("In").Output("o: float");
+REGISTER_OP("Op").Input("i: float").Output("o: float");
+
+static void BM_Subgraph(int iters, int num_nodes) {
+ DeviceAttributes device_info;
+ device_info.set_name("/job:a/replica:0/task:0/cpu:0");
+ device_info.set_device_type(DeviceType(DEVICE_CPU).type());
+ device_info.set_incarnation(0);
+
+ testing::StopTiming();
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* last_node = nullptr;
+ for (int i = 0; i < num_nodes; i++) {
+ string name = strings::StrCat("N", i);
+ if (i > 0) {
+ last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name));
+ } else {
+ last_node = ops::SourceOp("In", b.opts().WithName(name));
+ }
+ }
+ TF_CHECK_OK(b.ToGraph(&g));
+ }
+
+ std::vector<string> fed;
+ if (num_nodes > 1000) {
+ fed.push_back(strings::StrCat("N", num_nodes - 1000));
+ }
+ std::vector<string> fetch;
+ std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)};
+ testing::StartTiming();
+ while (--iters > 0) {
+ Graph* subgraph = new Graph(OpRegistry::Global());
+ CopyGraph(g, subgraph);
+ TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
+ targets, device_info));
+ delete subgraph;
+ }
+}
+BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
new file mode 100644
index 0000000000..f789110ff3
--- /dev/null
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -0,0 +1,41 @@
+#include "tensorflow/core/graph/tensor_id.h"
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+TensorId ParseTensorName(const string& name) {
+ return ParseTensorName(StringPiece(name.data(), name.size()));
+}
+
+TensorId ParseTensorName(StringPiece name) {
+ // Parse either a name, or a name:digits. To do so, we go backwards
+ // from the end of the string, skipping over a run of digits. If
+ // we hit a ':' character, then we know we are in the 'name:digits'
+ // regime. Otherwise, the output index is implicitly 0, and the whole
+ // name string forms the first part of the tensor name.
+ //
+ // Equivalent to matching with this regexp: ([^:]+):(\\d+)
+ const char* base = name.data();
+ const char* p = base + name.size() - 1;
+ int index = 0;
+ int mul = 1;
+ while (p > base && (*p >= '0' && *p <= '9')) {
+ index += ((*p - '0') * mul);
+ mul *= 10;
+ p--;
+ }
+ TensorId id;
+ if (p > base && *p == ':' && mul > 1) {
+ id.first = StringPiece(base, p - base);
+ id.second = index;
+ } else {
+ id.first = name;
+ id.second = 0;
+ }
+ return id;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
new file mode 100644
index 0000000000..f1f3846875
--- /dev/null
+++ b/tensorflow/core/graph/tensor_id.h
@@ -0,0 +1,28 @@
+#ifndef TENSORFLOW_GRAPH_TENSOR_ID_H_
+#define TENSORFLOW_GRAPH_TENSOR_ID_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+// Identifier for a tensor within a step.
+// first == operation_name, second == output_index
+// Note: does not own backing storage for name.
+struct TensorId : public std::pair<StringPiece, int> {
+ typedef std::pair<StringPiece, int> Base;
+
+ // Inherit the set of constructors.
+ using Base::pair;
+
+ string ToString() const { return strings::StrCat(first, ":", second); }
+};
+
+TensorId ParseTensorName(const string& name);
+TensorId ParseTensorName(StringPiece name);
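+
+// Examples (editor's note), matching the behavior exercised in
+// tensor_id_test.cc:
+//   ParseTensorName("W1")    => {"W1", 0}
+//   ParseTensorName("W1:17") => {"W1", 17}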
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_TENSOR_ID_H_
diff --git a/tensorflow/core/graph/tensor_id_test.cc b/tensorflow/core/graph/tensor_id_test.cc
new file mode 100644
index 0000000000..b945774cc3
--- /dev/null
+++ b/tensorflow/core/graph/tensor_id_test.cc
@@ -0,0 +1,77 @@
+#include "tensorflow/core/graph/tensor_id.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+static string ParseHelper(const string& n) {
+ TensorId id = ParseTensorName(n);
+ return strings::StrCat(id.first, ":", id.second);
+}
+
+TEST(TensorIdTest, ParseTensorName) {
+ EXPECT_EQ(ParseHelper("W1"), "W1:0");
+ EXPECT_EQ(ParseHelper("weights:0"), "weights:0");
+ EXPECT_EQ(ParseHelper("W1:1"), "W1:1");
+ EXPECT_EQ(ParseHelper("W1:17"), "W1:17");
+ EXPECT_EQ(ParseHelper("xyz1_17"), "xyz1_17:0");
+}
+
+static uint32 Skewed(random::SimplePhilox* rnd, int max_log) {
+ const uint32 space = 1 << (rnd->Rand32() % (max_log + 1));
+ return rnd->Rand32() % space;
+}
+
+static void BM_ParseTensorName(int iters, int arg) {
+ testing::StopTiming();
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<string> names;
+ for (int i = 0; i < 100; i++) {
+ string name;
+ switch (arg) {
+ case 0: { // Generate random names
+ size_t len = Skewed(&rnd, 4);
+ while (name.size() < len) {
+ name += rnd.OneIn(4) ? '0' : 'a';
+ }
+ if (rnd.OneIn(3)) {
+ strings::StrAppend(&name, ":", rnd.Uniform(12));
+ }
+ break;
+ }
+ case 1:
+ name = "W1";
+ break;
+ case 2:
+ name = "t0003";
+ break;
+ case 3:
+ name = "weights";
+ break;
+ case 4:
+ name = "weights:17";
+ break;
+ default:
+ LOG(FATAL) << "Unexpected arg";
+ break;
+ }
+ names.push_back(name);
+ }
+ testing::StartTiming();
+ TensorId id;
+ int index = 0;
+ int sum = 0;
+ while (--iters > 0) {
+ id = ParseTensorName(names[index++ % names.size()]);
+ sum += id.second;
+ }
+ VLOG(2) << sum; // Prevent compiler from eliminating loop body
+}
+BENCHMARK(BM_ParseTensorName)->Arg(0)->Arg(1)->Arg(2)->Arg(3)->Arg(4);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
new file mode 100644
index 0000000000..e49d5e819a
--- /dev/null
+++ b/tensorflow/core/graph/testlib.cc
@@ -0,0 +1,299 @@
+#include "tensorflow/core/graph/testlib.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace test {
+namespace graph {
+
+Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
+ const uint64 sender_incarnation, const string& receiver) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Send")
+ .Input(input, 0)
+ .Attr("tensor_name", tensor)
+ .Attr("send_device", sender)
+ .Attr("send_device_incarnation",
+ static_cast<int64>(sender_incarnation))
+ .Attr("recv_device", receiver)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Recv(Graph* g, const string& tensor, const string& type,
+ const string& sender, const uint64 sender_incarnation,
+ const string& receiver) {
+ Node* ret;
+ DataType dtype;
+ CHECK(DataTypeFromString(type, &dtype));
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Recv")
+ .Attr("tensor_type", dtype)
+ .Attr("tensor_name", tensor)
+ .Attr("send_device", sender)
+ .Attr("send_device_incarnation",
+ static_cast<int64>(sender_incarnation))
+ .Attr("recv_device", receiver)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Constant(Graph* g, const Tensor& tensor) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Const")
+ .Attr("dtype", tensor.dtype())
+ .Attr("value", tensor)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Constant(Graph* g, const Tensor& tensor, const string& name) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(name, "Const")
+ .Attr("dtype", tensor.dtype())
+ .Attr("value", tensor)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Var(Graph* g, const DataType dtype, const TensorShape& shape) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Variable")
+ .Attr("dtype", dtype)
+ .Attr("shape", shape)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Assign(Graph* g, Node* var, Node* val) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Assign")
+ .Input(var)
+ .Input(val)
+ .Attr("use_locking", true)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
+ bool keep_dims) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), reduce)
+ .Input(data)
+ .Input(axes)
+ .Attr("keep_dims", keep_dims)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* QuantizeToUINT8(Graph* g, Node* data) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Quantize")
+ .Input(data)
+ .Attr("T", DT_QUINT8)
+ .Attr("max_range", 1.0f)
+ .Attr("min_range", -1.0f)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
+ bool transpose_b) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MatMul")
+ .Input(in0)
+ .Input(in1)
+ .Attr("transpose_a", transpose_a)
+ .Attr("transpose_b", transpose_b)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* RandomNumberGenerator(const string& op, Graph* g, Node* input,
+ DataType dtype) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), op)
+ .Input(input)
+ .Attr("dtype", dtype)
+ .Attr("seed", 0)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* RandomUniform(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomUniform", g, input, dtype);
+}
+
+Node* RandomGaussian(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomStandardNormal", g, input, dtype);
+}
+
+Node* RandomParameters(Graph* g, Node* input, DataType dtype) {
+ return RandomNumberGenerator("RandomParameters", g, input, dtype);
+}
+
+Node* Unary(Graph* g, const string& func, Node* input, int index) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), func).Input(input, index).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Binary(Graph* g, const string& func, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), func)
+ .Input(in0)
+ .Input(in1)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins) {
+ Node* ret;
+ auto b = NodeBuilder(g->NewName("n"), func);
+ for (Node* n : ins) b = b.Input(n);
+ TF_CHECK_OK(b.Finalize(g, &ret));
+ return ret;
+}
+
+Node* Identity(Graph* g, Node* input, int index) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Identity")
+ .Input(input, index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Add(Graph* g, Node* in0, Node* in1) { return Binary(g, "Add", in0, in1); }
+
+Node* Error(Graph* g, Node* input, const string& errmsg) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
+ .Input(input)
+ .Attr("message", errmsg)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type) {
+ DCHECK(out_type != invalid_type);
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "InvalidRefType")
+ .Attr("TIn", out_type)
+ .Attr("TOut", invalid_type)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Delay(Graph* g, Node* input, Microseconds delay_micros) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Delay")
+ .Input(input)
+ .Attr("micros", delay_micros.value())
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "NoOp")
+ .ControlInputs(control_inputs)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Switch(Graph* g, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Switch")
+ .Input(in0)
+ .Input(in1)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Enter(Graph* g, Node* input, const string& frame_name) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Enter")
+ .Input(input)
+ .Attr("frame_name", frame_name)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Exit(Graph* g, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Exit").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Merge(Graph* g, Node* in0, Node* in1) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Merge")
+ .Input({in0, in1})
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in) {
+ std::vector<NodeBuilder::NodeOut> inputs;
+ inputs.reserve(remaining_in.size() + 1);
+ inputs.emplace_back(in0);
+ for (const string& in_name : remaining_in) {
+ inputs.emplace_back(in_name, 0, inputs[0].dt);
+ }
+
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Merge").Input(inputs).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Next(Graph* g, const string& name, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(name, "NextIteration").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* LoopCond(Graph* g, Node* input) {
+ Node* ret;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "LoopCond").Input(input).Finalize(g, &ret));
+ return ret;
+}
+
+Node* Less(Graph* g, Node* in0, Node* in1) {
+ return Binary(g, "Less", in0, in1);
+}
+
+Node* Select(Graph* g, Node* c, Node* inx, Node* iny) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Select")
+ .Input(c)
+ .Input(inx)
+ .Input(iny)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Cast(Graph* g, Node* in, DataType dst) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Cast")
+ .Input(in)
+ .Attr("DstT", dst)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
+
+} // end namespace graph
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
new file mode 100644
index 0000000000..11905bbf6a
--- /dev/null
+++ b/tensorflow/core/graph/testlib.h
@@ -0,0 +1,141 @@
+// DEPRECATED: Use GraphDefBuilder instead.
+
+#ifndef TENSORFLOW_GRAPH_TESTLIB_H_
+#define TENSORFLOW_GRAPH_TESTLIB_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+namespace test {
+namespace graph {
+
+// Converts "g" into its corresponding GraphDef "def".
+// DEPRECATED: call g->ToGraphDef(def) instead.
+void ToGraphDef(Graph* g, GraphDef* def);
+
+// A few helpers to construct a graph.
+
+// Adds a node in "g" producing a constant "tensor".
+Node* Constant(Graph* g, const Tensor& tensor);
+Node* Constant(Graph* g, const Tensor& tensor, const string& name);
+
+// Adds a variable in "g" of the given "shape" and "dtype".
+Node* Var(Graph* g, const DataType dtype, const TensorShape& shape);
+
+// Adds an assign node in "g" which assigns "val" into "var".
+Node* Assign(Graph* g, Node* var, Node* val);
+
+// Adds a send node in "g" sending "input" as a named "tensor" from
+// "sender" to "receiver".
+Node* Send(Graph* g, Node* input, const string& tensor, const string& sender,
+ const uint64 sender_incarnation, const string& receiver);
+
+// Adds a recv node in "g" receiving a named "tensor" from "sender"
+// to "receiver".
+Node* Recv(Graph* g, const string& tensor, const string& type,
+ const string& sender, const uint64 sender_incarnation,
+ const string& receiver);
+
+// Adds a reduction node in "g" computing reduce(data, axes), where
+// "reduce" names a reduction op, e.g., Sum, Max, Min, Mean, etc.
+Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes,
+ bool keep_dims = false);
+
+// Adds a Matmul node in g doing in0.contract(in1).
+Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a,
+ bool transpose_b);
+
+// Adds a Quantize node into "g" that quantizes floats into QUINT8. The range
+// of the input float tensor is assumed to be [-1, 1].
+Node* QuantizeToUINT8(Graph* g, Node* data);
+
+// Adds a unary function "func" node in "g" taking "input".
+Node* Unary(Graph* g, const string& func, Node* input, int index = 0);
+
+// Adds an identity node in "g" taking "input" and producing an
+// identity copy.
+Node* Identity(Graph* g, Node* input, int index = 0);
+
+// Adds a binary function "func" node in "g" taking "in0" and "in1".
+// Requires that "func" name an attr-style Op.
+Node* Binary(Graph* g, const string& func, Node* in0, Node* in1);
+
+// Adds a function "func" node in "g" taking inputs "ins".
+// Requires that "func" name an attr-style Op.
+Node* Multi(Graph* g, const string& func, gtl::ArraySlice<Node*> ins);
+
+// Adds a binary add node in "g" doing in0 + in1.
+Node* Add(Graph* g, Node* in0, Node* in1);
+
+// Generates random unit uniform distribution of the input shape.
+Node* RandomUniform(Graph* g, Node* input, DataType dtype);
+
+// Generates random unit normal distribution of the input shape.
+Node* RandomGaussian(Graph* g, Node* input, DataType dtype);
+
+// Generates random parameters from the truncated standard normal distribution
+// of the input shape.
+Node* RandomParameters(Graph* g, Node* input, DataType dtype);
+
+// Adds an error node in "g". The node's computation always
+// generates an error with the given error message "errmsg".
+Node* Error(Graph* g, Node* input, const string& errmsg);
+
+// Adds a node that generates an invalid ref output.
+Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type);
+
+// Adds a node in "g". Its Compute() sleeps a while and outputs the
+// input (i.e., same as identity).
+Node* Delay(Graph* g, Node* input, Microseconds delay_micros);
+
+// Adds a no-op node in "g", with control inputs from all nodes in the
+// "control_inputs" vector.
+Node* NoOp(Graph* g, const std::vector<Node*>& control_inputs);
+
+// Adds a Switch node in "g". If "in1" is true, it forwards "in0" to
+// output 1. Otherwise, it forwards "in0" to output 0.
+Node* Switch(Graph* g, Node* in0, Node* in1);
+
+// Adds an Enter node in "g", which enters a new frame.
+Node* Enter(Graph* g, Node* input, const string& frame_name);
+
+// Adds an Exit node in "g", which exits a frame.
+Node* Exit(Graph* g, Node* input);
+
+// Adds a Merge node in "g" with two inputs "in0" and "in1".
+Node* Merge(Graph* g, Node* in0, Node* in1);
+
+// Adds a Merge node in "g". The first input is "in0", the remaining
+// inputs are only given by their names in remaining_in.
+Node* Merge(Graph* g, Node* in0, gtl::ArraySlice<string> remaining_in);
+
+// Adds a NextIteration node in "g", which makes its input available
+// to the next iteration.
+Node* Next(Graph* g, const string& name, Node* input);
+
+// Adds a LoopCond node in "g", representing the "pivot" termination
+// condition of a loop.
+Node* LoopCond(Graph* g, Node* input);
+
+// Adds a less node in "g", which returns true iff "in0" < "in1".
+Node* Less(Graph* g, Node* in0, Node* in1);
+
+// Adds a select node in "g", which outputs either "inx" or "iny"
+// depending on the boolean value of "c".
+Node* Select(Graph* g, Node* c, Node* inx, Node* iny);
+
+// Casts "in" into data type "dst".
+Node* Cast(Graph* g, Node* in, DataType dst);
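+
+// Editor's sketch of how these helpers compose (names are illustrative, not
+// part of the original change):
+//   Graph g(OpRegistry::Global());
+//   Tensor t(DT_FLOAT, TensorShape({2}));
+//   t.flat<float>().setZero();
+//   Node* c = Constant(&g, t);
+//   Node* sum = Add(&g, c, c);               // sum = c + c
+//   Node* casted = Cast(&g, sum, DT_DOUBLE);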
+
+} // end namespace graph
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_TESTLIB_H_
diff --git a/tensorflow/core/graph/types.h b/tensorflow/core/graph/types.h
new file mode 100644
index 0000000000..41400611a9
--- /dev/null
+++ b/tensorflow/core/graph/types.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_GRAPH_TYPES_H_
+#define TENSORFLOW_GRAPH_TYPES_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+
+namespace tensorflow {
+
+// We model running time in microseconds.
+TF_LIB_GTL_DEFINE_INT_TYPE(Microseconds, int64);
+
+// We model size in bytes.
+TF_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64);
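+
+// Editor's sketch of typical usage: these are strongly typed integers, so they
+// do not mix implicitly with plain int64 or with each other.
+//   Microseconds delay(250);
+//   Bytes bufsize(1 << 20);
+//   int64 raw_micros = delay.value();  // explicit access to the raw value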
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_GRAPH_TYPES_H_
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc
new file mode 100644
index 0000000000..7cc0534354
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_contrast_op.cc
@@ -0,0 +1,121 @@
+// See docs in ../ops/image_ops.cc
+#define EIGEN_USE_THREADS
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/adjust_contrast_op.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class AdjustContrastOp : public OpKernel {
+ public:
+ explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& factor = context->input(1);
+ const Tensor& min_value = context->input(2);
+ const Tensor& max_value = context->input(3);
+ OP_REQUIRES(context, input.dims() >= 3,
+ errors::InvalidArgument("input must be at least 3-D, got shape ",
+ input.shape().ShortDebugString()));
+ const int64 height = input.dim_size(input.dims() - 3);
+ const int64 width = input.dim_size(input.dims() - 2);
+ const int64 channels = input.dim_size(input.dims() - 1);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor.shape()),
+ errors::InvalidArgument("contrast_factor must be scalar: ",
+ factor.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(min_value.shape()),
+ errors::InvalidArgument("min_value must be scalar: ",
+ min_value.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(max_value.shape()),
+ errors::InvalidArgument("max_value must be scalar: ",
+ max_value.shape().ShortDebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+ Tensor mean_values;
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::value,
+ TensorShape(input.shape()),
+ &mean_values));
+
+ if (input.NumElements() > 0) {
+ const int64 batch = input.NumElements() / (height * width * channels);
+ const int64 shape[4] = {batch, height, width, channels};
+ functor::AdjustContrast<Device, T>()(
+ context->eigen_device<Device>(), input.shaped<T, 4>(shape),
+ factor.scalar<float>(), min_value.scalar<float>(),
+ max_value.scalar<float>(), mean_values.shaped<float, 4>(shape),
+ output->shaped<float, 4>(shape));
+ }
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AdjustContrast").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ AdjustContrastOp<CPUDevice, T>);
+
+REGISTER_KERNEL(uint8);
+REGISTER_KERNEL(int8);
+REGISTER_KERNEL(int16);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU (to prevent
+// building the GPU versions here; they are built when compiling _gpu.cu.cc).
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void AdjustContrast<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
+ typename TTypes<float>::ConstScalar contrast_factor, \
+ typename TTypes<float>::ConstScalar min_value, \
+ typename TTypes<float>::ConstScalar max_value, \
+ typename TTypes<float, 4>::Tensor mean_values, \
+ typename TTypes<float, 4>::Tensor output); \
+ extern template struct AdjustContrast<GPUDevice, T>;
+
+DECLARE_GPU_SPEC(uint8);
+DECLARE_GPU_SPEC(int8);
+DECLARE_GPU_SPEC(int16);
+DECLARE_GPU_SPEC(int32);
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AdjustContrast").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ AdjustContrastOp<GPUDevice, T>);
+REGISTER_GPU_KERNEL(uint8);
+REGISTER_GPU_KERNEL(int8);
+REGISTER_GPU_KERNEL(int16);
+REGISTER_GPU_KERNEL(int32);
+REGISTER_GPU_KERNEL(float);
+REGISTER_GPU_KERNEL(double);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
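The shape handling above treats every dimension before the last three as batch: for example, an input of shape [5, 2, 64, 64, 3] has height 64, width 64, channels 3, and batch = (5 * 2 * 64 * 64 * 3) / (64 * 64 * 3) = 10, so the functor processes it as a [10, 64, 64, 3] tensor.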
diff --git a/tensorflow/core/kernels/adjust_contrast_op.h b/tensorflow/core/kernels/adjust_contrast_op.h
new file mode 100644
index 0000000000..2182b33c03
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_contrast_op.h
@@ -0,0 +1,64 @@
+#ifndef TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
+#define TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by AdjustContrastOp to do the computations.
+template <typename Device, typename T>
+struct AdjustContrast {
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
+ typename TTypes<float>::ConstScalar contrast_factor,
+ typename TTypes<float>::ConstScalar min_value,
+ typename TTypes<float>::ConstScalar max_value,
+ typename TTypes<float, 4>::Tensor mean_values,
+ typename TTypes<float, 4>::Tensor output) {
+ const int batch = input.dimension(0);
+ const int height = input.dimension(1);
+ const int width = input.dimension(2);
+ const int channels = input.dimension(3);
+
+ Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 2> reduction_axis{{1, 2}};
+ Eigen::array<int, 4> scalar{{1, 1, 1, 1}};
+ Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
+ Eigen::Tensor<int, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
+#else
+ Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >
+ reduction_axis;
+ Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<1>,
+ Eigen::type2index<1>, Eigen::type2index<1> > scalar;
+ Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
+ broadcast_dims;
+ broadcast_dims.set(1, height);
+ broadcast_dims.set(2, width);
+ Eigen::IndexList<int, Eigen::type2index<1>, Eigen::type2index<1>, int>
+ reshape_dims;
+ reshape_dims.set(0, batch);
+ reshape_dims.set(3, channels);
+#endif
+ mean_values.device(d) = input.template cast<float>()
+ .mean(reduction_axis)
+ .eval()
+ .reshape(reshape_dims)
+ .broadcast(broadcast_dims);
+
+ auto contrast_factor_tensor =
+ contrast_factor.reshape(scalar).broadcast(scalar_broadcast);
+ auto adjusted =
+ (input.template cast<float>() - mean_values) * contrast_factor_tensor +
+ mean_values;
+ auto min_bcast = min_value.reshape(scalar).broadcast(scalar_broadcast);
+ auto max_bcast = max_value.reshape(scalar).broadcast(scalar_broadcast);
+ // TODO(wicke): This is rather slow and should be re-written as pure cuda.
+ output.device(d) = adjusted.cwiseMin(max_bcast).cwiseMax(min_bcast);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
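The Eigen expression above computes, for every pixel, (input - mean) * contrast_factor + mean, where the mean is taken per image and per channel over height and width, and then clamps the result to [min_value, max_value]. A standalone scalar sketch of the same per-pixel formula (not part of this patch):

#include <algorithm>

// Applies the contrast-adjustment formula used above to a single value,
// given the per-image, per-channel mean.
float AdjustContrastPixel(float x, float channel_mean, float factor,
                          float min_value, float max_value) {
  const float adjusted = (x - channel_mean) * factor + channel_mean;
  return std::min(max_value, std::max(min_value, adjusted));
}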
diff --git a/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc
new file mode 100644
index 0000000000..75b177cf4d
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc
@@ -0,0 +1,43 @@
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* BM_AdjustContrast(int batches, int width, int height) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor in(DT_UINT8, TensorShape({batches, width, height, 3}));
+ in.flat<uint8>().setRandom();
+ Tensor factor(DT_FLOAT, TensorShape({}));
+ factor.flat<float>().setConstant(1.2);
+ Tensor min_value(DT_FLOAT, TensorShape({}));
+ min_value.flat<float>().setConstant(7.);
+ Tensor max_value(DT_FLOAT, TensorShape({}));
+ max_value.flat<float>().setConstant(250.);
+
+ Node* ret;
+ NodeBuilder(g->NewName("n"), "AdjustContrast")
+ .Input(test::graph::Constant(g, in))
+ .Input(test::graph::Constant(g, factor))
+ .Input(test::graph::Constant(g, min_value))
+ .Input(test::graph::Constant(g, max_value))
+ .Finalize(g, &ret);
+ return g;
+}
+
+#define BM_AdjustContrastDev(DEVICE, B, W, H) \
+ static void BM_AdjustContrast_##DEVICE##_##B##_##W##_##H(int iters) { \
+ testing::ItemsProcessed(iters* B* W* H * 3); \
+ test::Benchmark(#DEVICE, BM_AdjustContrast(B, W, H)).Run(iters); \
+ } \
+ BENCHMARK(BM_AdjustContrast_##DEVICE##_##B##_##W##_##H);
+
+// Benchmark results as of cl/106323955
+// BM_AdjustContrast_cpu_1_299_299 3416770 22008951 100 11.6M items/s
+
+// BM_AdjustContrast_gpu_32_299_299 37117844 45512374 100 179.8M items/s
+BM_AdjustContrastDev(cpu, 1, 299, 299) BM_AdjustContrastDev(gpu, 32, 299, 299)
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc b/tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc
new file mode 100644
index 0000000000..7a9b0726fd
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_contrast_op_gpu.cu.cc
@@ -0,0 +1,22 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/adjust_contrast_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::AdjustContrast<GPUDevice, uint8>;
+template struct functor::AdjustContrast<GPUDevice, int8>;
+template struct functor::AdjustContrast<GPUDevice, int16>;
+template struct functor::AdjustContrast<GPUDevice, int32>;
+template struct functor::AdjustContrast<GPUDevice, int64>;
+template struct functor::AdjustContrast<GPUDevice, float>;
+template struct functor::AdjustContrast<GPUDevice, double>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/adjust_contrast_op_test.cc b/tensorflow/core/kernels/adjust_contrast_op_test.cc
new file mode 100644
index 0000000000..67891e4fa1
--- /dev/null
+++ b/tensorflow/core/kernels/adjust_contrast_op_test.cc
@@ -0,0 +1,88 @@
+#include "tensorflow/core/framework/allocator.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class AdjustContrastOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() { RequireDefaultOps(); }
+};
+
+TEST_F(AdjustContrastOpTest, Simple_1113) {
+ RequireDefaultOps();
+ EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("T", DT_FLOAT)
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({1, 1, 1, 3}), {-1, 2, 3});
+ AddInputFromArray<float>(TensorShape({}), {1.0});
+ AddInputFromArray<float>(TensorShape({}), {0.0});
+ AddInputFromArray<float>(TensorShape({}), {2.0});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 3}));
+ test::FillValues<float>(&expected, {0, 2, 2});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(AdjustContrastOpTest, Simple_1223) {
+ RequireDefaultOps();
+ EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("T", DT_FLOAT)
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 3}),
+ {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12});
+ AddInputFromArray<float>(TensorShape({}), {0.2});
+ AddInputFromArray<float>(TensorShape({}), {0.0});
+ AddInputFromArray<float>(TensorShape({}), {10.0});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 2, 3}));
+ test::FillValues<float>(
+ &expected, {2.2, 6.2, 10, 2.4, 6.4, 10, 2.6, 6.6, 10, 2.8, 6.8, 10});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(AdjustContrastOpTest, Big_99x99x3) {
+ EXPECT_OK(NodeDefBuilder("adjust_constrast_op", "AdjustContrast")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("T", DT_FLOAT)
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+
+ std::vector<float> values;
+ for (int i = 0; i < 99 * 99 * 3; ++i) {
+ values.push_back(i % 255);
+ }
+
+ AddInputFromArray<float>(TensorShape({1, 99, 99, 3}), values);
+ AddInputFromArray<float>(TensorShape({}), {0.2});
+ AddInputFromArray<float>(TensorShape({}), {0});
+ AddInputFromArray<float>(TensorShape({}), {255});
+ ASSERT_OK(RunOpKernel());
+}
+
+} // namespace tensorflow
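The expected values in Simple_1223 follow directly from the formula: the per-channel means over the 2x2 spatial grid are 2.5, 6.5 and 10.5, so with contrast_factor 0.2 the value 1 maps to (1 - 2.5) * 0.2 + 2.5 = 2.2, 5 maps to 6.2, and 9 maps to 10.2; the third channel's results (10.2 through 10.8) exceed max_value = 10 and are clamped, which is why every third entry of the expected tensor is 10.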
diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc
new file mode 100644
index 0000000000..426e868735
--- /dev/null
+++ b/tensorflow/core/kernels/aggregate_ops.cc
@@ -0,0 +1,238 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/aggregate_ops.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class AddNOp : public OpKernel {
+ public:
+ explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ if (!ctx->ValidateInputsAreSameShape(this)) return;
+
+ const Tensor& input0 = ctx->input(0);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
+ auto To = output->flat<T>();
+
+ const int num = ctx->num_inputs();
+ if (num == 1) {
+ *output = input0;
+ return;
+ }
+
+#define I(IDX) ctx->input(IDX).flat<T>()
+
+#if defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID)
+ // On Android, we only support additions of two arguments, so we
+ // can reduce the number of template instantiations.
+ OP_REQUIRES(ctx, num == 2,
+ errors::InvalidArgument("Only additions of two arguments "
+ "supported. Num inputs: ",
+ num));
+ functor::Add2Functor<Device, T> functor2;
+ functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
+#else
+ static const int kWidth = 8;
+ int r = num % kWidth;
+
+ switch (r) {
+ case 2: {
+ functor::Add2Functor<Device, T> functor2;
+ functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
+ break;
+ }
+ case 3: {
+ functor::Add3Functor<Device, T> functor3;
+ functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
+ break;
+ }
+ case 4: {
+ functor::Add4Functor<Device, T> functor4;
+ functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+ I(3));
+ break;
+ }
+ case 5: {
+ functor::Add5Functor<Device, T> functor5;
+ functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+ I(3), I(4));
+ break;
+ }
+ case 6: {
+ functor::Add6Functor<Device, T> functor6;
+ functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+ I(3), I(4), I(5));
+ break;
+ }
+ case 7: {
+ functor::Add7Functor<Device, T> functor7;
+ functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+ I(3), I(4), I(5), I(6));
+ break;
+ }
+ case 0: {
+ functor::Add8Functor<Device, T> functor8;
+ functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+ I(3), I(4), I(5), I(6), I(7));
+ r = 8;
+ break;
+ }
+ case 1: {
+ functor::Add9Functor<Device, T> functor9;
+ functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
+ I(3), I(4), I(5), I(6), I(7), I(8));
+ r = 9;
+ break;
+ }
+ }
+
+ for (; r < num; r += kWidth) {
+ functor::Add8pFunctor<Device, T> functor8p;
+ functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
+ I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
+ }
+#endif // defined(PLATFORM_POSIX_ANDROID) || defined(PLATFORM_GOOGLE_ANDROID)
+
+#undef I
+ }
+};
+
+// Partial specializations for CPUDevice that use the Eigen implementations
+// (Add*EigenImpl).
+namespace functor {
+template <typename T>
+struct Add2Functor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2) {
+ Add2EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2);
+ }
+};
+template <typename T>
+struct Add3Functor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3) {
+ Add3EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3);
+ }
+};
+template <typename T>
+struct Add4Functor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4) {
+ Add4EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
+ }
+};
+template <typename T>
+struct Add5Functor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5) {
+ Add5EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
+ }
+};
+template <typename T>
+struct Add6Functor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6) {
+ Add6EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
+ }
+};
+template <typename T>
+struct Add7Functor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7) {
+ Add7EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7);
+ }
+};
+
+template <typename T>
+struct Add8Functor<CPUDevice, T> {
+ void operator()(
+ const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add8pFunctor<CPUDevice, T> {
+ void operator()(
+ const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8pEigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add9Functor<CPUDevice, T> {
+ void operator()(
+ const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9) {
+ Add9EigenImpl<CPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8, in9);
+ }
+};
+
+} // namespace functor
+
+#define REGISTER_ADDN(type, dev) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
+ AddNOp<dev##Device, type>)
+
+#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
+
+TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
+#undef REGISTER_ADDN_CPU
+
+#if GOOGLE_CUDA
+REGISTER_ADDN(float, GPU);
+#endif // GOOGLE_CUDA
+
+#undef REGISTER_ADDN
+
+} // namespace tensorflow
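The dispatch above handles the first num % 8 inputs with a single AddK call (a remainder of 0 or 1 is bumped to 8 or 9, since there is no Add0 or Add1), then folds in the remaining inputs eight at a time with the accumulating Add8p. A small standalone sketch of that schedule, separate from the kernel itself:

#include <cstdio>

// Prints the AddK calls AddNOp would issue for `num` inputs (num >= 2).
void PrintAddNSchedule(int num) {
  const int kWidth = 8;
  int r = num % kWidth;
  if (r == 0) r = 8;  // a full first chunk
  if (r == 1) r = 9;  // there is no Add1, so take nine up front instead
  std::printf("Add%d on inputs [0, %d)\n", r, r);
  for (; r < num; r += kWidth) {
    std::printf("Add8p (accumulate) on inputs [%d, %d)\n", r, r + kWidth);
  }
}
// PrintAddNSchedule(20): Add4 on [0, 4), then Add8p on [4, 12) and [12, 20).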
diff --git a/tensorflow/core/kernels/aggregate_ops.h b/tensorflow/core/kernels/aggregate_ops.h
new file mode 100644
index 0000000000..2214901970
--- /dev/null
+++ b/tensorflow/core/kernels/aggregate_ops.h
@@ -0,0 +1,211 @@
+#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
+#define TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
+
+// Functor definitions for aggregate ops; these must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct Add2Functor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2);
+};
+
+template <typename Device, typename T>
+struct Add2EigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2) {
+ out.device(d) = in1 + in2;
+ }
+};
+
+template <typename Device, typename T>
+struct Add3Functor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3);
+};
+
+template <typename Device, typename T>
+struct Add3EigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3) {
+ out.device(d) = in1 + in2 + in3;
+ }
+};
+
+template <typename Device, typename T>
+struct Add4Functor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4);
+};
+
+template <typename Device, typename T>
+struct Add4EigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4) {
+ out.device(d) = in1 + in2 + in3 + in4;
+ }
+};
+
+template <typename Device, typename T>
+struct Add5Functor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5);
+};
+
+template <typename Device, typename T>
+struct Add5EigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5) {
+ out.device(d) = in1 + in2 + in3 + in4 + in5;
+ }
+};
+
+template <typename Device, typename T>
+struct Add6Functor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6);
+};
+
+template <typename Device, typename T>
+struct Add6EigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6) {
+ out.device(d) = in1 + in2 + in3 + in4 + in5 + in6;
+ }
+};
+
+template <typename Device, typename T>
+struct Add7Functor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7);
+};
+
+template <typename Device, typename T>
+struct Add7EigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7) {
+ out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7;
+ }
+};
+
+template <typename Device, typename T>
+struct Add8Functor {
+ void operator()(
+ const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
+};
+
+template <typename Device, typename T>
+struct Add8EigenImpl {
+ static void Compute(
+ const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
+ }
+};
+
+// Add8p is like Add8 except that the underlying implementation should
+// accumulate into (+=) the output rather than assign to it.
+template <typename Device, typename T>
+struct Add8pFunctor {
+ void operator()(
+ const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8);
+};
+
+template <typename Device, typename T>
+struct Add8pEigenImpl {
+ static void Compute(
+ const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8;
+ }
+};
+
+template <typename Device, typename T>
+struct Add9Functor {
+ void operator()(
+ const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9);
+};
+
+template <typename Device, typename T>
+struct Add9EigenImpl {
+ static void Compute(
+ const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9) {
+ out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9;
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
new file mode 100644
index 0000000000..5cf2934ac1
--- /dev/null
+++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
@@ -0,0 +1,141 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/aggregate_ops.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specializations for GPUDevice that use the Eigen implementations.
+namespace functor {
+template <typename T>
+struct Add2Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2) {
+ Add2EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2);
+ }
+};
+
+template <typename T>
+struct Add3Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3) {
+ Add3EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3);
+ }
+};
+
+template <typename T>
+struct Add4Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4) {
+ Add4EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
+ }
+};
+
+template <typename T>
+struct Add5Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5) {
+ Add5EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
+ }
+};
+
+template <typename T>
+struct Add6Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6) {
+ Add6EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
+ }
+};
+
+template <typename T>
+struct Add7Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7) {
+ Add7EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7);
+ }
+};
+
+template <typename T>
+struct Add8Functor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add8pFunctor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8pEigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add9Functor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9) {
+ Add9EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8, in9);
+ }
+};
+
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::Add2Functor<GPUDevice, float>;
+template struct functor::Add3Functor<GPUDevice, float>;
+template struct functor::Add4Functor<GPUDevice, float>;
+template struct functor::Add5Functor<GPUDevice, float>;
+template struct functor::Add6Functor<GPUDevice, float>;
+template struct functor::Add7Functor<GPUDevice, float>;
+template struct functor::Add8Functor<GPUDevice, float>;
+template struct functor::Add8pFunctor<GPUDevice, float>;
+template struct functor::Add9Functor<GPUDevice, float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
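Kernels that run these functors on GPU rely on the specializations being instantiated in a .cu.cc translation unit compiled by nvcc (here for float only); other files in this patch, such as adjust_contrast_op.cc above, additionally suppress host-side instantiation with extern template declarations. A minimal self-contained illustration of that declare/instantiate split, using a made-up Square functor in place of the real ones:

// functor.h: the shared template definition.
template <typename T>
struct Square {
  T operator()(T x) const { return x * x; }
};

// op.cc: may use Square<float>, but suppresses instantiation in this TU.
extern template struct Square<float>;

// op_gpu.cu.cc: the one place Square<float> is actually instantiated.
template struct Square<float>;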
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc
new file mode 100644
index 0000000000..0845eebf09
--- /dev/null
+++ b/tensorflow/core/kernels/argmax_op.cc
@@ -0,0 +1,163 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/argmax_op.h"
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T, typename ArgFunctor>
+class ArgOp : public OpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& dimension = context->input(1);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(dimension.shape()),
+ errors::InvalidArgument(
+ "dim must be a scalar, but received tensor of shape: ",
+ dimension.shape().DebugString()));
+
+ const int32 dim = dimension.scalar<int32>()();
+ const int input_dims = input.dims();
+
+ OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
+ OP_REQUIRES(context, dim < input_dims,
+ errors::InvalidArgument("Minimum tensor rank: ", dim,
+ " but got: ", input_dims));
+
+ TensorShape output_shape;
+ TensorShape input_shape = input.shape();
+ for (int d = 0; d < input_dims - 1; ++d) {
+ output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
+ }
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+#define HANDLE_DIM(NDIM) \
+ case NDIM: \
+ ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \
+ input.tensor<T, NDIM>(), dim, \
+ output->tensor<int64, NDIM - 1>()); \
+ break;
+
+ switch (input_dims) {
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+
+ default:
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument(
+ "ArgOp : Unhandled input dimensions: ", input_dims));
+ }
+ }
+#undef HANDLE_DIM
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
+template <typename Device, typename T>
+class ArgMaxOp : public ArgOp<Device, T, functor::ArgMax<Device, T> > {
+ public:
+ explicit ArgMaxOp(OpKernelConstruction* context)
+ : ArgOp<Device, T, functor::ArgMax<Device, T> >(context) {}
+};
+
+template <typename Device, typename T>
+class ArgMinOp : public ArgOp<Device, T, functor::ArgMin<Device, T> > {
+ public:
+ explicit ArgMinOp(OpKernelConstruction* context)
+ : ArgOp<Device, T, functor::ArgMin<Device, T> >(context) {}
+};
+
+#define REGISTER_ARGMAX(type) \
+ REGISTER_KERNEL_BUILDER(Name("ArgMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMaxOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("ArgMin") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMinOp<CPUDevice, type>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void ArgMax<GPUDevice, T>::Reduce##Dims( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ const int32 dimension, typename TTypes<int64, Dims - 1>::Tensor output); \
+ template <> \
+ void ArgMin<GPUDevice, T>::Reduce##Dims( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ const int32 dimension, typename TTypes<int64, Dims - 1>::Tensor output);
+
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPEC(T, 1); \
+ DECLARE_GPU_SPEC(T, 2); \
+ DECLARE_GPU_SPEC(T, 3); \
+ DECLARE_GPU_SPEC(T, 4); \
+ DECLARE_GPU_SPEC(T, 5);
+
+#define DECLARE_GPU_CLASS(T) \
+ extern template struct ArgMax<GPUDevice, T>; \
+ extern template struct ArgMin<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
+
+#undef DECLARE_GPU_SPECS
+#undef DECLARE_GPU_CLASS
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_ARGMAX_GPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("ArgMax") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMaxOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("ArgMin") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dimension"), \
+ ArgMinOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
+
+#undef REGISTER_ARGMAX_GPU
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
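The output shape built above is the input shape with dimension dim removed: for example, ArgMax on an input of shape [2, 3, 5] with dim = 1 produces an int64 output of shape [2, 5], where output(i, k) is the index j in [0, 3) at which input(i, j, k) is largest.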
diff --git a/tensorflow/core/kernels/argmax_op.h b/tensorflow/core/kernels/argmax_op.h
new file mode 100644
index 0000000000..41734f3254
--- /dev/null
+++ b/tensorflow/core/kernels/argmax_op.h
@@ -0,0 +1,55 @@
+#ifndef TENSORFLOW_KERNELS_ARGMAX_OP_H_
+#define TENSORFLOW_KERNELS_ARGMAX_OP_H_
+// Functor definitions for ArgMax and ArgMin ops; must be compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct ArgMax {
+#define DECLARE_COMPUTE_SPEC(Dims) \
+ EIGEN_ALWAYS_INLINE static void Reduce##Dims( \
+ const Device& d, typename TTypes<T, Dims>::ConstTensor input, \
+ const int32 dimension, \
+ typename TTypes<int64, Dims - 1>::Tensor output) { \
+ output.device(d) = input.argmax(dimension).template cast<int64>(); \
+ }
+
+ DECLARE_COMPUTE_SPEC(1);
+ DECLARE_COMPUTE_SPEC(2);
+ DECLARE_COMPUTE_SPEC(3);
+ DECLARE_COMPUTE_SPEC(4);
+ DECLARE_COMPUTE_SPEC(5);
+
+#undef DECLARE_COMPUTE_SPEC
+};
+
+template <typename Device, typename T>
+struct ArgMin {
+#define DECLARE_COMPUTE_SPEC(Dims) \
+ EIGEN_ALWAYS_INLINE static void Reduce##Dims( \
+ const Device& d, typename TTypes<T, Dims>::ConstTensor input, \
+ const int32 dimension, \
+ typename TTypes<int64, Dims - 1>::Tensor output) { \
+ output.device(d) = input.argmin(dimension).template cast<int64>(); \
+ }
+
+ DECLARE_COMPUTE_SPEC(1);
+ DECLARE_COMPUTE_SPEC(2);
+ DECLARE_COMPUTE_SPEC(3);
+ DECLARE_COMPUTE_SPEC(4);
+ DECLARE_COMPUTE_SPEC(5);
+
+#undef DECLARE_COMPUTE_SPEC
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_ARGMAX_OP_H_
diff --git a/tensorflow/core/kernels/argmax_op_gpu.cu.cc b/tensorflow/core/kernels/argmax_op_gpu.cu.cc
new file mode 100644
index 0000000000..6c91fc2c86
--- /dev/null
+++ b/tensorflow/core/kernels/argmax_op_gpu.cu.cc
@@ -0,0 +1,20 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/argmax_op.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_SPEC(T) \
+ template struct functor::ArgMax<GPUDevice, T>; \
+ template struct functor::ArgMin<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h
new file mode 100644
index 0000000000..3306f1eeaa
--- /dev/null
+++ b/tensorflow/core/kernels/assign_op.h
@@ -0,0 +1,92 @@
+#ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_
+#define TENSORFLOW_KERNELS_ASSIGN_OP_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// TODO(jeff): Get rid of use_exclusive_lock_ option
+
+// Computes *input[0] = input[1]
+class AssignOp : public OpKernel {
+ public:
+ explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("use_locking", &use_exclusive_lock_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("validate_shape", &validate_shape_));
+ OP_REQUIRES(context, IsRefType(context->input_type(0)),
+ errors::InvalidArgument("lhs input needs to be a ref type"));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ Tensor rhs = context->input(1);
+
+ // We always return the input ref.
+ context->forward_ref_input_to_ref_output(0, 0);
+
+    // If the left-hand side is not initialized, or the shape of the
+    // right-hand side differs from that of the left-hand side, we need
+    // to allocate a new tensor.
+ {
+ mutex_lock l(*context->input_ref_mutex(0));
+
+ Tensor old_lhs = context->mutable_input(0, true);
+
+ if (validate_shape_) {
+ OP_REQUIRES(
+ context, old_lhs.shape().IsSameSize(rhs.shape()),
+ errors::InvalidArgument(
+ "Assign requires shapes of both tensors to match. lhs shape= ",
+ old_lhs.shape().ShortDebugString(), " rhs shape= ",
+ rhs.shape().ShortDebugString()));
+ }
+
+ const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
+ if (!old_lhs.IsInitialized() || !same_shape) {
+        // Create a new tensor whose shape matches the right-hand side,
+        // copy into it, then hand it off to the lhs.
+ // We can't always know how this value will be used downstream,
+ // so make conservative assumptions in specifying the memory
+ // allocation attributes.
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ PersistentTensor copy;
+ Tensor* copyTensor = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_persistent(old_lhs.dtype(), rhs.shape(),
+ &copy, &copyTensor, attr));
+ Copy(context, copyTensor, rhs);
+ context->replace_ref_input(0, *copyTensor, true);
+ return;
+ }
+
+ // The tensor has already been initialized and the right hand side
+ // matches the left hand side's shape.
+ if (use_exclusive_lock_) {
+ Copy(context, &old_lhs, rhs);
+ return;
+ }
+ }
+
+ // The tensor has already been initialized and the right hand side
+ // matches the left hand side's shape. We have been told to do the
+ // copy outside the lock.
+ Tensor old_unlocked_lhs = context->mutable_input(0, false);
+ Copy(context, &old_unlocked_lhs, rhs);
+ }
+
+ virtual void Copy(OpKernelContext* context, Tensor* lhs,
+ const Tensor& rhs) = 0;
+
+ bool use_exclusive_lock_;
+ bool validate_shape_;
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_ASSIGN_OP_H_
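Concretely, the logic above takes one of three paths: assigning, say, a [3, 4] value to a ref that is uninitialized or currently holds a [2, 2] tensor allocates a fresh persistent tensor and swaps it in under the mutex (or fails with InvalidArgument when validate_shape is set); assigning a [2, 2] value to a [2, 2] ref copies in place, under the mutex when use_locking is set and outside it otherwise.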
diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc
new file mode 100644
index 0000000000..28763f65a4
--- /dev/null
+++ b/tensorflow/core/kernels/attention_ops.cc
@@ -0,0 +1,92 @@
+// See docs in ../ops/attention_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+
+namespace tensorflow {
+
+class ExtractGlimpseOp : public OpKernel {
+ public:
+ explicit ExtractGlimpseOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("normalized", &normalized_));
+ OP_REQUIRES_OK(context, context->GetAttr("centered", &centered_));
+ OP_REQUIRES_OK(context, context->GetAttr("uniform_noise", &uniform_noise_));
+ }
+
+ // Expect input tensor of rank 4 with dimensions (batch_size, height, width,
+ // depth).
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape input_shape = input.shape();
+ const int32 num_dims = input_shape.dims();
+ OP_REQUIRES(
+ context, num_dims == 4,
+ errors::InvalidArgument(
+ "input must be 4-dimensional (batch_size, height, width, depth)",
+ input_shape.ShortDebugString()));
+
+ const int64 batch_size = input_shape.dim_size(0);
+
+ const Tensor& window_size = context->input(1);
+ OP_REQUIRES(context, (window_size.shape().dims() == 1) &&
+ window_size.shape().dim_size(0) == 2,
+ errors::InvalidArgument(
+ "input must be a vector of size 2 (height, width)",
+ window_size.shape().ShortDebugString()));
+
+ const int64 output_height = window_size.tensor<int, 1>()(0);
+ const int64 output_width = window_size.tensor<int, 1>()(1);
+ TensorShape output_shape = input_shape;
+ output_shape.set_dim(1, output_height);
+ output_shape.set_dim(2, output_width);
+
+ const Tensor& offsets = context->input(2);
+ OP_REQUIRES(context, offsets.shape().dims() == 2,
+ errors::InvalidArgument("input must be a matrix",
+ offsets.shape().ShortDebugString()));
+ OP_REQUIRES(context, offsets.shape().dim_size(0) == batch_size,
+ errors::InvalidArgument("first dimension should be batch",
+ offsets.shape().ShortDebugString()));
+ OP_REQUIRES(
+ context, offsets.shape().dim_size(1) == 2,
+ errors::InvalidArgument("second dimension should be of size 2 (y,x)",
+ offsets.shape().ShortDebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+ std::vector<Eigen::IndexPair<float> > offset_vec;
+ offset_vec.reserve(batch_size);
+ for (int i = 0; i < batch_size; ++i) {
+ float offset_y = offsets.tensor<float, 2>()(i, 0);
+ float offset_x = offsets.tensor<float, 2>()(i, 1);
+      // Eigen::ExtractGlimpses expects offsets as (x,y), whereas the
+      // TensorFlow op takes them as (y,x) indices, so swap them here.
+ offset_vec.push_back(Eigen::IndexPair<float>(offset_x, offset_y));
+ }
+
+ output->tensor<float, 4>().swap_layout().device(
+ context->eigen_cpu_device()) =
+ Eigen::ExtractGlimpses(input.tensor<float, 4>().swap_layout(),
+ output_width, output_height, offset_vec,
+ normalized_, centered_, uniform_noise_);
+ }
+
+ private:
+ bool normalized_;
+ bool centered_;
+ bool uniform_noise_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExtractGlimpse").Device(DEVICE_CPU),
+ ExtractGlimpseOp);
+
+} // end namespace tensorflow
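As the comment in the loop notes, the two coordinate conventions differ: a glimpse centered 10 rows down and 4 columns across arrives in the offsets input as (y, x) = (10, 4), and the loop passes Eigen::IndexPair<float>(4, 10) so Eigen receives it in its (x, y) order.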
diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc
new file mode 100644
index 0000000000..26f98ffbcd
--- /dev/null
+++ b/tensorflow/core/kernels/avgpooling_op.cc
@@ -0,0 +1,418 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/avgpooling_op.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/pooling_ops_common.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
+#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class AvgPoolingOp : public UnaryOp<T> {
+ public:
+ explicit AvgPoolingOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ PoolParameters params{context, ksize_, stride_, padding_,
+ tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+ OP_REQUIRES(context, params.depth_window == 1,
+ errors::Unimplemented(
+ "Non-spatial pooling is not "
+ "yet supported. Volunteers? :)"));
+
+ // For avgpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional"));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, params.forward_output_shape(), &output));
+
+ if (std::is_same<Device, GPUDevice>::value) {
+ Eigen::PaddingType pt = BrainPadding2EigenPadding(padding_);
+ functor::SpatialAvgPooling<Device, T>()(
+ context->eigen_device<Device>(), output->tensor<T, 4>(),
+ tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
+ params.row_stride, params.col_stride, pt);
+ } else {
+ SpatialAvgPool<Device, T>(context, output, tensor_in, params, padding_);
+ }
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AvgPool")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ AvgPoolingOp<CPUDevice, float>);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void SpatialAvgPooling<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
+ typename TTypes<T, 4>::ConstTensor input, int window_rows, \
+ int window_cols, int row_stride, int col_stride, \
+ const Eigen::PaddingType& padding); \
+ extern template struct SpatialAvgPooling<GPUDevice, T>;
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNEL_BUILDER(Name("AvgPool")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ AvgPoolingOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+// The operation to compute AvgPool gradients.
+// It takes two inputs:
+// - The original input tensor shape
+// - Backprop tensor for output
+// It produces one output: backprop tensor for input.
+template <typename Device, class T>
+class AvgPoolingGradOp : public OpKernel {
+ public:
+ explicit AvgPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in_shape = context->input(0);
+ const Tensor& out_backprop = context->input(1);
+ // For avgpooling, tensor_in_shape should have 1 dimension, and 4 elements.
+ OP_REQUIRES(context, tensor_in_shape.dims() == 1 &&
+ tensor_in_shape.NumElements() == 4,
+ errors::InvalidArgument(
+ "out_backprop must be 1-dimensional and 4 "
+ "elements"));
+ // For avgpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+ const int64 out_backprop_batch = out_backprop.dim_size(0);
+ const int64 out_backprop_rows = out_backprop.dim_size(1);
+ const int64 out_backprop_cols = out_backprop.dim_size(2);
+ const int64 out_backprop_depth = out_backprop.dim_size(3);
+
+ TensorShape output_shape;
+ auto shape_vec = tensor_in_shape.vec<int32>();
+ for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
+ output_shape.AddDim(shape_vec(i));
+ }
+ const int64 in_rows = output_shape.dim_size(1);
+ const int64 in_cols = output_shape.dim_size(2);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ output->flat<T>().setZero();
+
+ const int window_rows = ksize_[1];
+ const int window_cols = ksize_[2];
+ const int depth_window = ksize_[3];
+
+ const int row_stride = stride_[1];
+ const int col_stride = stride_[2];
+
+ // We (will) use different code for spatial pooling and
+ // non-spatial pooling.
+ //
+ // Spatial pooling is when depth_window = 1
+ OP_REQUIRES(context, depth_window == 1,
+ errors::Unimplemented(
+ "Non-spatial pooling is not "
+ "yet supported. Volunteers? :)"));
+
+ int out_height, out_width, pad_rows, pad_cols;
+ OP_REQUIRES_OK(
+ context, Get2dOutputSize(in_rows, in_cols, window_rows, window_cols,
+ row_stride, col_stride, padding_, &out_height,
+ &out_width, &pad_rows, &pad_cols));
+
+ const T* out_backprop_ptr = out_backprop.flat<T>().data();
+ T* input_backprop_ptr = output->flat<T>().data();
+
+ for (int64 b = 0; b < out_backprop_batch; ++b) {
+ for (int64 r = 0; r < out_backprop_rows; ++r) {
+ // Calculates row broadcast size. For SAME padding, current
+ // index could be in the padding area, and r*row_stride +
+ // window_rows could be beyond the input tensor's boundary. In
+ // such cases, change the starting index and reduce the
+ // broadcast size.
+ int rindex, rsize;
+ OP_REQUIRES_OK(context,
+ GetBroadcastSize(r, in_rows, window_rows, row_stride,
+ pad_rows, &rindex, &rsize));
+ for (int64 c = 0; c < out_backprop_cols; ++c) {
+ // Calculates col broadcast size. For SAME padding, current
+ // index could be in the padding area, and c*col_stride +
+ // window_cols could be beyond the input tensor's boundary. In
+ // such cases, change the starting index and reduce the
+ // broadcast size.
+ int cindex, csize;
+ OP_REQUIRES_OK(context,
+ GetBroadcastSize(c, in_cols, window_cols, col_stride,
+ pad_cols, &cindex, &csize));
+
+ T divide_coeff = 1.0 / (rsize * csize);
+ int64 output_index =
+ (b * out_backprop_rows + r) * out_backprop_cols + c;
+ for (int64 r_dst = rindex; r_dst < rindex + rsize; ++r_dst) {
+ for (int64 c_dst = cindex; c_dst < cindex + csize; ++c_dst) {
+ int64 input_index = (b * in_rows + r_dst) * in_cols + c_dst;
+ const T* output_offset =
+ out_backprop_ptr + output_index * out_backprop_depth;
+ T* input_offset =
+ input_backprop_ptr + input_index * out_backprop_depth;
+ for (int64 d = 0; d < out_backprop_depth; ++d) {
+ *input_offset += *output_offset * divide_coeff;
+ ++output_offset;
+ ++input_offset;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("orig_input_shape"),
+ AvgPoolingGradOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("orig_input_shape"),
+ AvgPoolingGradOp<CPUDevice, double>);
+
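Each out_backprop element is spread uniformly over the input window it averaged, which is what the 1 / (rsize * csize) coefficient above implements: with a VALID 2x2 window and stride 2, an output gradient of 1.0 contributes 0.25 to each of its four inputs, while with SAME padding a clipped corner window covering only a 1x2 slice of the input uses a coefficient of 0.5 instead.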
+#if GOOGLE_CUDA
+
+// A CUDNN based AvgPoolingGrad implementation. It includes the padding as the
+// candidates for the pooling operation.
+template <class T>
+class AvgPoolingGradOp<GPUDevice, T> : public OpKernel {
+ public:
+ typedef GPUDevice Device;
+
+ explicit AvgPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in_shape = context->input(0);
+ const Tensor& out_backprop = context->input(1);
+    // For avgpooling, tensor_in_shape should have 1 dimension with 4 elements.
+    OP_REQUIRES(
+        context,
+        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
+        errors::InvalidArgument("orig_input_shape must be 1-dimensional and "
+                                "have 4 elements"));
+ // For avgpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+
+ TensorShape output_shape;
+ auto shape_vec = tensor_in_shape.vec<int32>();
+ for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
+ output_shape.AddDim(shape_vec(i));
+ }
+
+ DnnPoolingGradOp<T>::Compute(
+ context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
+ stride_, padding_, nullptr, nullptr, out_backprop, output_shape);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("orig_input_shape")
+ .Label("cudnn"),
+ AvgPoolingGradOp<GPUDevice, float>);
+
+// A custom-GPU-kernel-based AvgPoolingGrad implementation. It counts padded
+// elements as candidates for the pooling operation.
+template <class T>
+class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
+ public:
+ typedef GPUDevice Device;
+
+ explicit AvgPoolingGradOpCustomGPUKernel(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in_shape = context->input(0);
+ const Tensor& out_backprop = context->input(1);
+    // For avgpooling, tensor_in_shape should have 1 dimension with 4 elements.
+    OP_REQUIRES(
+        context,
+        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
+        errors::InvalidArgument("orig_input_shape must be 1-dimensional and "
+                                "have 4 elements"));
+ // For avgpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+ const int64 out_backprop_batch = out_backprop.dim_size(0);
+ const int64 out_backprop_rows = out_backprop.dim_size(1);
+ const int64 out_backprop_cols = out_backprop.dim_size(2);
+ const int64 out_backprop_depth = out_backprop.dim_size(3);
+
+ TensorShape output_shape;
+ auto shape_vec = tensor_in_shape.vec<int32>();
+ for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
+ output_shape.AddDim(shape_vec(i));
+ }
+ const int64 in_rows = output_shape.dim_size(1);
+ const int64 in_cols = output_shape.dim_size(2);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+ const int window_rows = ksize_[1];
+ const int window_cols = ksize_[2];
+ const int depth_window = ksize_[3];
+
+ const int row_stride = stride_[1];
+ const int col_stride = stride_[2];
+
+ // We (will) use different code for spatial pooling and
+ // non-spatial pooling.
+ //
+ // Spatial pooling is when depth_window = 1
+ OP_REQUIRES(context, depth_window == 1,
+ errors::Unimplemented("Non-spatial pooling is not "
+ "yet supported. Volunteers? :)"));
+
+ int out_height, out_width, pad_rows, pad_cols;
+ OP_REQUIRES_OK(
+ context, Get2dOutputSize(in_rows, in_cols, window_rows, window_cols,
+ row_stride, col_stride, padding_, &out_height,
+ &out_width, &pad_rows, &pad_cols));
+
+ RunAvePoolBackwardNHWC<T>(out_backprop.flat<T>().data(), // top_diff
+ out_backprop_batch, // num
+ in_rows, // height
+ in_cols, // width
+ out_backprop_depth, // channels
+ out_backprop_rows, // pooled_height
+ out_backprop_cols, // pooled_width
+ window_rows, // kernel_h
+ window_cols, // kernel_w
+ row_stride, // stride_h
+ col_stride, // stride_w
+ pad_rows, // pad_t
+ pad_cols, // pad_l
+ output->flat<T>().data(), // bottom_diff
+ context->eigen_gpu_device()); // d
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("orig_input_shape"),
+ AvgPoolingGradOpCustomGPUKernel<float>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h
new file mode 100644
index 0000000000..38f0eb97e5
--- /dev/null
+++ b/tensorflow/core/kernels/avgpooling_op.h
@@ -0,0 +1,58 @@
+#ifndef TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
+#define TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
+// Functor definition for AvgPoolingOp, must be compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct SpatialAvgPooling {
+ void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
+ typename TTypes<T, 4>::ConstTensor input, int window_rows,
+ int window_cols, int row_stride, int col_stride,
+ const Eigen::PaddingType& padding) {
+ // Because we swap the layout, we swap the row/cols as well
+ output.swap_layout().device(d) =
+ Eigen::SpatialAvgPooling(input.swap_layout(), window_cols, window_rows,
+ col_stride, row_stride, padding);
+ }
+};
+
+} // namespace functor
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Launches a custom GPU kernel (from Yanqing) for the avgpooling backward
+// operation that works on NHWC data formats.
+// Arguments:
+// top_diff: backprop to the output of the pooling layer
+// num: number of input batches
+// height: input height
+// width: input width
+// channels: number of input channels
+// pooled_height: the height of the output to the pooling layer
+// pooled_width: the width of the output to the pooling layer
+// kernel_h: the height of the pooling kernel
+// kernel_w: the width of the pooling kernel
+//     stride_h: the vertical stride
+//     stride_w: the horizontal stride
+// pad_t: padding size to the top side
+// pad_l: padding size to the left side
+// bottom_diff: backprop to the input of the pooling layer.
+template <typename T>
+bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t,
+ const int pad_l, T* const bottom_diff,
+ const GPUDevice& d);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
new file mode 100644
index 0000000000..ec84ee6862
--- /dev/null
+++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
@@ -0,0 +1,101 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "tensorflow/core/kernels/avgpooling_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::SpatialAvgPooling<GPUDevice, T>;
+
+DEFINE_GPU_KERNELS(float)
+
+#undef DEFINE_GPU_KERNELS
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+static const int CAFFE_CUDA_NUM_THREADS = 1024;
+
+template <typename dtype>
+__global__ void AvePoolBackwardNHWC(const int nthreads,
+ const dtype* const top_diff, const int num,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t,
+ const int pad_l, dtype* const bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // find out the local index
+ // find out the local offset
+ const int c = index % channels;
+ const int w = index / channels % width + pad_l;
+ const int h = (index / channels / width) % height + pad_t;
+ const int n = index / channels / width / height;
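+    // [phstart, phend) x [pwstart, pwend) is the range of pooled outputs
+    // whose windows contain the (padded) input position (h, w).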
+ const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
+ const int phend = min(h / stride_h + 1, pooled_height);
+ const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
+ const int pwend = min(w / stride_w + 1, pooled_width);
+ dtype gradient = 0;
+ const dtype* const top_diff_slice =
+ top_diff + n * pooled_height * pooled_width * channels + c;
+ for (int ph = phstart; ph < phend; ++ph) {
+ for (int pw = pwstart; pw < pwend; ++pw) {
+ // figure out the pooling size
+ int hstart = ph * stride_h - pad_t;
+ int wstart = pw * stride_w - pad_l;
+ int hend = min(hstart + kernel_h, height);
+ int wend = min(wstart + kernel_w, width);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ int pool_size = (hend - hstart) * (wend - wstart);
+ gradient +=
+ top_diff_slice[(ph * pooled_width + pw) * channels] / pool_size;
+ }
+ }
+ bottom_diff[index] = gradient;
+ }
+}
+
+template <typename T>
+bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t,
+ const int pad_l, T* const bottom_diff,
+ const GPUDevice& d) {
+ int x_size = num * height * width * channels;
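+  // Launch one thread per element of bottom_diff (the input gradient); the
+  // block size is capped by both CAFFE_CUDA_NUM_THREADS and the device limit.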
+ int thread_per_block =
+ std::min(CAFFE_CUDA_NUM_THREADS, d.maxCudaThreadsPerMultiProcessor());
+ int block_count = (x_size + thread_per_block - 1) / thread_per_block;
+ AvePoolBackwardNHWC<T><<<block_count, thread_per_block, 0, d.stream()>>>(
+ x_size, top_diff, num, height, width, channels, pooled_height,
+      pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ bottom_diff);
+
+ return d.ok();
+}
+
+template bool RunAvePoolBackwardNHWC(
+ const float* const top_diff, const int num, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ float* const bottom_diff, const GPUDevice& d);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc
new file mode 100644
index 0000000000..349aac0158
--- /dev/null
+++ b/tensorflow/core/kernels/batch_matmul_op.cc
@@ -0,0 +1,260 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename Scalar>
+struct LaunchBatchMatMul;
+
+template <typename Scalar>
+struct LaunchBatchMatMul<CPUDevice, Scalar> {
+ static void Launch(OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+ auto Tx = in_x.tensor<Scalar, 3>();
+ auto Ty = in_y.tensor<Scalar, 3>();
+ auto Tz = out->tensor<Scalar, 3>();
+
+ // Shards "n"-matmuls into "num" shards. Each shard is
+ // dispatched to a thread.
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ const int64 num_units = in_x.dim_size(0);
+ const int64 cost_per_unit =
+ in_x.dim_size(0) * in_x.dim_size(1) * out->dim_size(2);
+ Shard(worker_threads.num_threads, worker_threads.workers, num_units,
+ cost_per_unit, [&Tx, &Ty, adj_x, adj_y, &Tz](int start, int limit) {
+ LaunchBatchMatMul<CPUDevice, Scalar>::Run(Tx, Ty, adj_x, adj_y, Tz,
+ start, limit);
+ });
+ }
+
+ template <typename In, typename Out>
+ static void Run(In Tx, In Ty, bool adj_x, bool adj_y, Out Tz, int start,
+ int limit) {
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+
+ Eigen::internal::scalar_conjugate_op<Scalar> conj;
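+    // The contraction pair selects which axes get summed: a plain matmul
+    // contracts x's axis 1 with y's axis 0. Adjointing an operand flips which
+    // of its axes is contracted, and its elements are conjugated via
+    // unaryExpr(conj).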
+ if (!adj_x && !adj_y) {
+ for (int i = start; i < limit; ++i) {
+ auto x = Tx.template chip<0>(i);
+ auto y = Ty.template chip<0>(i);
+ auto z = Tz.template chip<0>(i);
+ contract_pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
+ z = x.contract(y, contract_pairs); // matmul
+ }
+ } else if (!adj_x && adj_y) {
+ for (int i = start; i < limit; ++i) {
+ auto x = Tx.template chip<0>(i);
+ auto y = Ty.template chip<0>(i).unaryExpr(conj);
+ auto z = Tz.template chip<0>(i);
+ contract_pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 1);
+ z = x.contract(y, contract_pairs); // matmul
+ }
+ } else if (adj_x && !adj_y) {
+ for (int i = start; i < limit; ++i) {
+ auto x = Tx.template chip<0>(i).unaryExpr(conj);
+ auto y = Ty.template chip<0>(i);
+ auto z = Tz.template chip<0>(i);
+ contract_pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(0, 0);
+ z = x.contract(y, contract_pairs); // matmul
+ }
+ } else {
+ for (int i = start; i < limit; ++i) {
+ auto x = Tx.template chip<0>(i).unaryExpr(conj);
+ auto y = Ty.template chip<0>(i).unaryExpr(conj);
+ auto z = Tz.template chip<0>(i);
+ contract_pairs[0] = Eigen::IndexPair<Eigen::DenseIndex>(0, 1);
+ z = x.contract(y, contract_pairs); // matmul
+ }
+ }
+ }
+};
+
+#if GOOGLE_CUDA
+
+namespace {
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+} // namespace
+
+template <typename Scalar>
+struct LaunchBatchMatMul<GPUDevice, Scalar> {
+ static void Launch(OpKernelContext* context, const Tensor& in_x,
+ const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
+ perftools::gputools::blas::Transpose trans[] = {
+ perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose};
+ const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
+ const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
+ const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
+ const uint64 batch_size = in_x.dim_size(0);
+ auto blas_transpose_a = trans[adj_x];
+ auto blas_transpose_b = trans[adj_y];
+
+ auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+ std::vector<DeviceMemoryType> a_device_memory;
+ std::vector<DeviceMemoryType> b_device_memory;
+ std::vector<DeviceMemoryType> c_device_memory;
+ std::vector<DeviceMemoryType*> a_ptrs;
+ std::vector<DeviceMemoryType*> b_ptrs;
+ std::vector<DeviceMemoryType*> c_ptrs;
+ a_device_memory.reserve(batch_size);
+ b_device_memory.reserve(batch_size);
+ c_device_memory.reserve(batch_size);
+ a_ptrs.reserve(batch_size);
+ b_ptrs.reserve(batch_size);
+ c_ptrs.reserve(batch_size);
+ auto* a_base_ptr = in_x.template flat<Scalar>().data();
+ auto* b_base_ptr = in_y.template flat<Scalar>().data();
+ auto* c_base_ptr = out->template flat<Scalar>().data();
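+    // Build per-batch device pointers: batch i's A, B and C matrices start at
+    // offsets i*m*k, i*k*n and i*m*n in the flat buffers.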
+ for (int64 i = 0; i < batch_size; ++i) {
+ a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
+ b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
+ c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
+ a_ptrs.push_back(&a_device_memory.back());
+ b_ptrs.push_back(&b_device_memory.back());
+ c_ptrs.push_back(&c_device_memory.back());
+ }
+
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column major.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A' (' stands for transpose)
+ bool blas_launch_status =
+ stream->ThenBlasGemmBatched(blas_transpose_b, blas_transpose_a, n, m, k,
+ static_cast<Scalar>(1.0), b_ptrs,
+ adj_y ? k : n, a_ptrs, adj_x ? m : k,
+ static_cast<Scalar>(0.0), c_ptrs, n,
+ batch_size)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal(
+ "Blas SGEMMBatched launch failed : a.shape=",
+ in_x.shape().DebugString(), ", b.shape=", in_y.shape().DebugString(),
+ ", m=", m, ", n=", n, ", k=", k, ", batch_size=", batch_size));
+ }
+ }
+};
+
+#endif // GOOGLE_CUDA
+
+template <typename Device, typename Scalar>
+class BatchMatMul : public OpKernel {
+ public:
+ explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
+ OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
+ }
+
+ virtual ~BatchMatMul() {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in0 = ctx->input(0);
+ const Tensor& in1 = ctx->input(1);
+ OP_REQUIRES(ctx, in0.dims() == in1.dims(),
+                errors::InvalidArgument("In[0] and In[1] have different ndims: ",
+ in0.shape().ShortDebugString(), " vs. ",
+ in1.shape().ShortDebugString()));
+ const int ndims = in0.dims();
+ OP_REQUIRES(
+ ctx, ndims >= 3,
+ errors::InvalidArgument("In[0] and In[1] ndims must be >= 3: ", ndims));
+ TensorShape out_shape;
+ for (int i = 0; i < ndims - 2; ++i) {
+ OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
+ errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(",
+ i, ") must be the same: ",
+ in0.shape().DebugString(), " vs ",
+ in1.shape().DebugString()));
+ out_shape.AddDim(in0.dim_size(i));
+ }
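+    // Collapse all leading (batch) dimensions into a single dimension n so
+    // both inputs can be treated as rank-3 tensors of shape [n, rows, cols].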
+ auto n = out_shape.num_elements();
+ auto d0 = in0.dim_size(ndims - 2);
+ auto d1 = in0.dim_size(ndims - 1);
+ Tensor in0_reshaped;
+ CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
+ auto d2 = in1.dim_size(ndims - 2);
+ auto d3 = in1.dim_size(ndims - 1);
+ Tensor in1_reshaped;
+ CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
+ if (adj_x_) std::swap(d0, d1);
+ if (adj_y_) std::swap(d2, d3);
+ OP_REQUIRES(ctx, d1 == d2,
+ errors::InvalidArgument(
+ "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
+ in0.shape().ShortDebugString(), " ",
+ in1.shape().ShortDebugString(), " ", adj_x_, " ", adj_y_));
+ out_shape.AddDim(d0);
+ out_shape.AddDim(d3);
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
+ if (out->NumElements() == 0) {
+ return;
+ }
+ if (in0.NumElements() == 0 || in1.NumElements() == 0) {
+ functor::SetZeroFunctor<Device, Scalar> f;
+ f(ctx->eigen_device<Device>(), out->flat<Scalar>());
+ return;
+ }
+ Tensor out_reshaped;
+ CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
+ LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
+ adj_x_, adj_y_, &out_reshaped);
+ }
+
+ private:
+ bool adj_x_;
+ bool adj_y_;
+};
+
+#define REGISTER_CPU(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
+ BatchMatMul<CPUDevice, TYPE>)
+
+#define REGISTER_GPU(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+ BatchMatMul<GPUDevice, TYPE>)
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+REGISTER_CPU(int32);
+REGISTER_CPU(complex64);
+
+#ifdef GOOGLE_CUDA
+// TODO(kalakris): The GPU implementation is currently disabled due to issues
+// encountered in practice. See b/24534272.
+// REGISTER_GPU(float);
+#endif // GOOGLE_CUDA
+
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc
new file mode 100644
index 0000000000..c67c921631
--- /dev/null
+++ b/tensorflow/core/kernels/batch_norm_op.cc
@@ -0,0 +1,223 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/batch_norm_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class BatchNormOp : public OpKernel {
+ public:
+ explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("variance_epsilon", &variance_epsilon_));
+ OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
+ &scale_after_normalization_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& mean = context->input(1);
+ const Tensor& var = context->input(2);
+ const Tensor& beta = context->input(3);
+ const Tensor& gamma = context->input(4);
+
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().ShortDebugString()));
+ OP_REQUIRES(context, mean.dims() == 1,
+ errors::InvalidArgument("mean must be 1-dimensional",
+ mean.shape().ShortDebugString()));
+ OP_REQUIRES(context, var.dims() == 1,
+ errors::InvalidArgument("var must be 1-dimensional",
+ var.shape().ShortDebugString()));
+ OP_REQUIRES(context, beta.dims() == 1,
+ errors::InvalidArgument("beta must be 1-dimensional",
+ beta.shape().ShortDebugString()));
+ OP_REQUIRES(context, gamma.dims() == 1,
+ errors::InvalidArgument("gamma must be 1-dimensional",
+ gamma.shape().ShortDebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+ functor::BatchNorm<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
+ var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
+ scale_after_normalization_, output->tensor<T, 4>());
+ }
+
+ private:
+ float variance_epsilon_;
+ bool scale_after_normalization_;
+};
+
+template <typename Device, typename T>
+class BatchNormGradOp : public OpKernel {
+ public:
+ explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("variance_epsilon", &variance_epsilon_));
+ OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
+ &scale_after_normalization_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& mean = context->input(1);
+ const Tensor& var = context->input(2);
+ const Tensor& gamma = context->input(3);
+ const Tensor& out_backprop = context->input(4);
+
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().ShortDebugString()));
+ OP_REQUIRES(context, mean.dims() == 1,
+ errors::InvalidArgument("mean must be 1-dimensional",
+ mean.shape().ShortDebugString()));
+ OP_REQUIRES(context, var.dims() == 1,
+ errors::InvalidArgument("var must be 1-dimensional",
+ var.shape().ShortDebugString()));
+ OP_REQUIRES(context, gamma.dims() == 1,
+ errors::InvalidArgument("gamma must be 1-dimensional",
+ gamma.shape().ShortDebugString()));
+ OP_REQUIRES(
+ context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional",
+ out_backprop.shape().ShortDebugString()));
+
+ Tensor* dx = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, input.shape(), &dx));
+ Tensor* dm = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, mean.shape(), &dm));
+ Tensor* dv = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(2, var.shape(), &dv));
+ Tensor* db = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(3, mean.shape(), &db));
+ Tensor* dg = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));
+
+ // Scratch buffer of [depth] dimension, aka the 4th dimension of input,
+ // which is dim_size(3), for calculating various combinations of
+ // (var + epsilon).
+ Tensor scratch1;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({input.dim_size(3)}), &scratch1));
+
+ // Scratch buffer of [depth] dimension for saving intermediate calculation
+ // values.
+ Tensor scratch2;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({input.dim_size(3)}), &scratch2));
+
+ functor::BatchNormGrad<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
+ var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
+ variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
+ dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
+ scratch1.vec<T>(), scratch2.vec<T>());
+ }
+
+ private:
+ float variance_epsilon_;
+ bool scale_after_normalization_;
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ BatchNormOp<CPUDevice, T>);
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void BatchNorm<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
+ typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \
+ typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma, \
+ float variance_epsilon, bool scale_after_normalization, \
+ typename TTypes<T, 4>::Tensor output); \
+ extern template struct BatchNorm<GPUDevice, T>;
+
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
+
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T"), \
+ BatchNormOp<GPUDevice, T>);
+
+REGISTER_GPU_KERNEL(float);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ BatchNormGradOp<CPUDevice, T>);
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void BatchNormGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
+ typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \
+ typename TTypes<T>::ConstVec gamma, \
+ typename TTypes<T, 4>::ConstTensor out_backprop, float variance_epsilon, \
+ bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx, \
+ typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv, \
+ typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg, \
+ typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2); \
+ extern template struct BatchNormGrad<GPUDevice, T>;
+
+#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
+
+DECLARE_GPU_SPECS(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T"), \
+ BatchNormGradOp<GPUDevice, T>);
+
+REGISTER_GPU_KERNEL(float);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_norm_op.h b/tensorflow/core/kernels/batch_norm_op.h
new file mode 100644
index 0000000000..5981e58460
--- /dev/null
+++ b/tensorflow/core/kernels/batch_norm_op.h
@@ -0,0 +1,133 @@
+#ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
+#define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
+// Functor definition for BatchNormOp, must be compilable by nvcc.
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by BatchNormOp to do the computations.
+template <typename Device, typename T>
+struct BatchNorm {
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
+ typename TTypes<T>::ConstVec mean,
+ typename TTypes<T>::ConstVec var,
+ typename TTypes<T>::ConstVec beta,
+ typename TTypes<T>::ConstVec gamma, float variance_epsilon,
+ bool scale_after_normalization,
+ typename TTypes<T, 4>::Tensor output) {
+ const int depth = mean.dimension(0);
+ const int rest_size = input.size() / depth;
+
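+    // All of the work is done on a 2-D view of shape [rest_size, depth]: the
+    // per-channel vectors (mean, var, beta, gamma) are reshaped to [1, depth]
+    // and broadcast across the rest_size (batch * spatial) rows.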
+ Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::DSizes<int, 2> rest_by_one(rest_size, 1);
+ Eigen::DSizes<int, 2> one_by_depth(1, depth);
+ Eigen::DSizes<int, 2> depth_by_one(depth, 1);
+#else
+ Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
+ rest_by_one.set(0, rest_size);
+ Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth;
+ one_by_depth.set(1, depth);
+ Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one;
+ depth_by_one.set(0, depth);
+#endif
+ if (scale_after_normalization) {
+ output.reshape(rest_by_depth).device(d) =
+ (input.reshape(rest_by_depth) -
+ mean.reshape(one_by_depth).broadcast(rest_by_one)) *
+ ((var + var.constant(variance_epsilon)).rsqrt() * gamma)
+ .eval()
+ .reshape(one_by_depth)
+ .broadcast(rest_by_one) +
+ beta.reshape(one_by_depth).broadcast(rest_by_one);
+ } else {
+ output.reshape(rest_by_depth).device(d) =
+ (input.reshape(rest_by_depth) -
+ mean.reshape(one_by_depth).broadcast(rest_by_one)) *
+ ((var + var.constant(variance_epsilon)).rsqrt())
+ .eval()
+ .reshape(one_by_depth)
+ .broadcast(rest_by_one) +
+ beta.reshape(one_by_depth).broadcast(rest_by_one);
+ }
+ }
+};
+
+template <typename Device, typename T>
+struct BatchNormGrad {
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
+ typename TTypes<T>::ConstVec mean,
+ typename TTypes<T>::ConstVec var,
+ typename TTypes<T>::ConstVec gamma,
+ typename TTypes<T, 4>::ConstTensor out_backprop,
+ float variance_epsilon, bool scale_after_normalization,
+ typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
+ typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
+ typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
+ typename TTypes<T>::Vec scratch2) {
+ const int depth = mean.dimension(0);
+ const int rest_size = input.size() / depth;
+
+ typedef typename TTypes<T>::ConstVec::Index Index;
+ Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
+ Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1);
+ Eigen::DSizes<Index, 2> one_by_depth(1, depth);
+
+    // With y = gamma * (x - m) * rsqrt(v + epsilon) + beta, the gradients
+    // w.r.t. the op's inputs are (when scale_after_normalization is false,
+    // gamma drops out of dx, dm and dv, and dg is set to zero):
+    //
+    // db = sum_over_rest(out_backprop)
+    //
+    // dg = sum_over_rest(out_backprop * (x - m)) * rsqrt(v + epsilon)
+    //
+    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
+    //      (-1/2) * (v + epsilon) ^ (-3/2)
+    //
+    // dm = -sum_over_rest(out_backprop) * gamma * rsqrt(v + epsilon)
+    //
+    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
+ Eigen::array<Index, 1> reduction_axis;
+ reduction_axis[0] = 0; // Reduces on first dimension.
+
+ db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);
+
+ // scratch1 = rsqrt(v + epsilon)
+ scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();
+
+ // scratch2 = sum_over_rest(out_backprop * (x - m))
+ scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
+ (input.reshape(rest_by_depth) -
+ mean.reshape(one_by_depth).broadcast(rest_by_one)))
+ .sum(reduction_axis);
+
+ if (scale_after_normalization) {
+ dx.reshape(rest_by_depth).device(d) =
+ out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
+ .eval()
+ .reshape(one_by_depth)
+ .broadcast(rest_by_one));
+ dm.device(d) = -db * (scratch1 * gamma).eval();
+ dg.device(d) = scratch2 * scratch1;
+ } else {
+ dx.reshape(rest_by_depth).device(d) =
+ out_backprop.reshape(rest_by_depth) *
+ scratch1.reshape(one_by_depth).broadcast(rest_by_one);
+ dm.device(d) = -db * scratch1;
+ dg.device(d) = dg.constant(static_cast<T>(0.0)); // Gamma is not learned.
+ }
+
+ // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2)
+ scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
+ (var + var.constant(variance_epsilon));
+
+ if (scale_after_normalization) {
+ dv.device(d) = scratch2 * (scratch1 * gamma).eval();
+ } else {
+ dv.device(d) = scratch2 * scratch1;
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
diff --git a/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc b/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc
new file mode 100644
index 0000000000..02e0eeecfa
--- /dev/null
+++ b/tensorflow/core/kernels/batch_norm_op_gpu.cu.cc
@@ -0,0 +1,17 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/batch_norm_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::BatchNorm<GPUDevice, float>;
+template struct functor::BatchNormGrad<GPUDevice, float>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc
new file mode 100644
index 0000000000..bb1492e5b4
--- /dev/null
+++ b/tensorflow/core/kernels/bcast_ops.cc
@@ -0,0 +1,71 @@
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+
+// Given shapes of two tensors, computes the reduction indices for the
+// gradient computation.
+//
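+// For example, with s0 = [2, 3, 5] and s1 = [5], the second operand is
+// broadcast along the two leading dimensions, so r0 = [] and r1 = [0, 1].
+//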
+// TODO(zhifengc):
+// 1. Adds support for n-ary (n >= 2).
+class BCastGradArgsOp : public OpKernel {
+ public:
+ explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(
+ ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() == 2,
+ errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
+ gtl::InlinedVector<BCast::Vec, 4> shapes;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ const Tensor& in = ctx->input(i);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
+ errors::InvalidArgument("In[", i, "] must be a vector.",
+ in.shape().ShortDebugString()));
+ BCast::Vec vec;
+      for (int64 j = 0; j < in.NumElements(); ++j) {
+        vec.push_back(in.vec<int32>()(j));
+      }
+ shapes.push_back(vec);
+ }
+ BCast bcast(shapes[0], shapes[1]);
+ OP_REQUIRES(ctx, bcast.IsValid(),
+ errors::InvalidArgument(
+ "Incompatible shapes: [", str_util::Join(shapes[0], ","),
+ "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ Output(ctx, 0, bcast.grad_x_reduce_idx());
+ Output(ctx, 1, bcast.grad_y_reduce_idx());
+ }
+
+ private:
+ void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
+ const int len = v.size();
+ Tensor* o = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
+ for (int i = 0; i < len; ++i) o->flat<int32>()(i) = v[i];
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+ .Device(DEVICE_CPU)
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0")
+ .HostMemory("r1"),
+ BCastGradArgsOp);
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+ .Device(DEVICE_GPU)
+ .HostMemory("s0")
+ .HostMemory("s1")
+ .HostMemory("r0")
+ .HostMemory("r1"),
+ BCastGradArgsOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
new file mode 100644
index 0000000000..68737f6c2d
--- /dev/null
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -0,0 +1,112 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bias_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class BiasOp : public BinaryOp<T> {
+ public:
+ explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& bias = context->input(1);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
+ errors::InvalidArgument("Input tensor must be at least 2D: ",
+ input.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
+ errors::InvalidArgument("Biases must be 1D: ",
+ bias.shape().DebugString()));
+ const auto last_dim = input.shape().dims() - 1;
+ OP_REQUIRES(
+ context, bias.shape().dim_size(0) == input.shape().dim_size(last_dim),
+ errors::InvalidArgument(
+ "Must provide as many biases as the last dimension "
+ "of the input tensor: ",
+ bias.shape().DebugString(), " vs. ", input.shape().DebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+ switch (input.shape().dims()) {
+ case 2:
+ Compute<2>(context, input, bias, output);
+ break;
+ case 3:
+ Compute<3>(context, input, bias, output);
+ break;
+ case 4:
+ Compute<4>(context, input, bias, output);
+ break;
+ case 5:
+ Compute<5>(context, input, bias, output);
+ break;
+ default:
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument("Only ranks up to 5 supported: ",
+ input.shape().DebugString()));
+ }
+ }
+
+  // Add biases to an input tensor of rank Dims using the Bias functor.
+ template <int Dims>
+ void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias,
+ Tensor* output) {
+ functor::Bias<Device, T, Dims> functor;
+ functor(ctx->eigen_device<Device>(), input.tensor<T, Dims>(), bias.vec<T>(),
+ output->tensor<T, Dims>());
+ }
+};
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ BiasOp<CPUDevice, type>);
+
+TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void Bias<GPUDevice, T, Dims>::operator()( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ typename TTypes<T>::ConstVec bias, \
+ typename TTypes<T, Dims>::Tensor output); \
+ extern template struct Bias<GPUDevice, T, Dims>;
+
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPEC(T, 2); \
+ DECLARE_GPU_SPEC(T, 3); \
+ DECLARE_GPU_SPEC(T, 4); \
+ DECLARE_GPU_SPEC(T, 5);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ BiasOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/bias_op.h b/tensorflow/core/kernels/bias_op.h
new file mode 100644
index 0000000000..513406d251
--- /dev/null
+++ b/tensorflow/core/kernels/bias_op.h
@@ -0,0 +1,41 @@
+#ifndef TENSORFLOW_KERNELS_BIAS_OP_H_
+#define TENSORFLOW_KERNELS_BIAS_OP_H_
+// Functor definition for BiasOp, must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by BiasOp to do the computations.
+template <typename Device, typename T, int Dims>
+struct Bias {
+ // Add "bias" to "input", broadcasting it on all dimensions but the last one.
+ void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
+ typename TTypes<T>::ConstVec bias,
+ typename TTypes<T, Dims>::Tensor output) {
+ const int bias_size = bias.dimension(0);
+ const int rest_size = input.size() / bias_size;
+
+ Eigen::DSizes<int, 2> rest_by_bias(rest_size, bias_size);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::DSizes<int, 2> rest_by_one(rest_size, 1);
+ Eigen::DSizes<int, 2> one_by_bias(1, bias_size);
+#else
+ Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
+ rest_by_one.set(0, rest_size);
+ Eigen::IndexList<Eigen::type2index<1>, int> one_by_bias;
+ one_by_bias.set(1, bias_size);
+#endif
+
+ output.reshape(rest_by_bias).device(d) =
+ input.reshape(rest_by_bias) +
+ bias.reshape(one_by_bias).broadcast(rest_by_one);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_BIAS_OP_H_
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
new file mode 100644
index 0000000000..d3377b3ce8
--- /dev/null
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -0,0 +1,23 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bias_op.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Definition of the GPU implementations declared in bias_op.cc.
+#define DEFINE_GPU_SPECS(T) \
+ template struct functor::Bias<GPUDevice, T, 2>; \
+ template struct functor::Bias<GPUDevice, T, 3>; \
+ template struct functor::Bias<GPUDevice, T, 4>; \
+ template struct functor::Bias<GPUDevice, T, 5>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
new file mode 100644
index 0000000000..cd5fde37a6
--- /dev/null
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -0,0 +1,243 @@
+// See docs in ../ops/candidate_sampling_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include <cfloat>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/range_sampler.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+class BaseCandidateSamplerOp : public OpKernel {
+ public:
+ explicit BaseCandidateSamplerOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_sampled", &num_sampled_));
+ OP_REQUIRES_OK(context, context->GetAttr("num_true", &num_true_));
+ OP_REQUIRES_OK(context, context->GetAttr("unique", &unique_));
+ OP_REQUIRES_OK(context, generator_.Init(context));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& true_classes = context->input(0);
+ OP_REQUIRES(context, true_classes.dims() == 2,
+ errors::InvalidArgument("true_classes must be a matrix"));
+ const int32 batch_size = true_classes.dim_size(0);
+ OP_REQUIRES(context, true_classes.dim_size(1) == num_true_,
+ errors::InvalidArgument("true_classes must have "
+ "num_true columns"));
+
+ // Output candidates and expected_count.
+ Tensor* out_sampled_candidates = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({num_sampled_}),
+ &out_sampled_candidates));
+
+ Tensor* out_true_expected_count = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, TensorShape({batch_size, num_true_}),
+ &out_true_expected_count));
+ Tensor* out_sampled_expected_count = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(2, TensorShape({num_sampled_}),
+ &out_sampled_expected_count));
+
+ gtl::ArraySlice<int64> true_candidate(true_classes.matrix<int64>().data(),
+ batch_size * num_true_);
+ gtl::MutableArraySlice<int64> sampled_candidate(
+ out_sampled_candidates->vec<int64>().data(), num_sampled_);
+ gtl::MutableArraySlice<float> true_expected_count(
+ out_true_expected_count->matrix<float>().data(),
+ batch_size * num_true_);
+ gtl::MutableArraySlice<float> sampled_expected_count(
+ out_sampled_expected_count->vec<float>().data(), num_sampled_);
+
+ CHECK(sampler_) << "CandidateSamplerOp did not set sampler_";
+
+    // Conservatively estimate the number of random samples required.
+ // In cases where rejection sampling is used we may occasionally use more
+ // samples than expected, which will result in reused random bits.
+ const int64 samples32 = 2048 * num_sampled_;
+
+ // Pick sampled candidates.
+ auto local_gen = generator_.ReserveSamples32(samples32);
+ random::SimplePhilox random(&local_gen);
+ sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate,
+ &sampled_expected_count,
+ true_candidate, &true_expected_count);
+
+ if (sampler_->NeedsUpdates()) {
+ sampler_->Update(true_candidate);
+ }
+ }
+
+ protected:
+ void set_sampler(RangeSampler* sampler) { sampler_.reset(sampler); }
+
+ private:
+ int32 num_true_;
+ int32 num_sampled_;
+ bool unique_;
+ std::unique_ptr<RangeSampler> sampler_;
+ GuardedPhiloxRandom generator_;
+};
+
+template <class RangeSamplerType>
+class SimpleCandidateSamplerOp : public BaseCandidateSamplerOp {
+ public:
+ explicit SimpleCandidateSamplerOp(OpKernelConstruction* context)
+ : BaseCandidateSamplerOp(context) {
+ int64 range_max;
+ OP_REQUIRES_OK(context, context->GetAttr("range_max", &range_max));
+ set_sampler(new RangeSamplerType(range_max));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UniformCandidateSampler").Device(DEVICE_CPU),
+ SimpleCandidateSamplerOp<UniformSampler>);
+
+REGISTER_KERNEL_BUILDER(Name("LogUniformCandidateSampler").Device(DEVICE_CPU),
+ SimpleCandidateSamplerOp<LogUniformSampler>);
+
+REGISTER_KERNEL_BUILDER(Name("LearnedUnigramCandidateSampler")
+ .Device(DEVICE_CPU),
+ SimpleCandidateSamplerOp<UnigramSampler>);
+
+REGISTER_KERNEL_BUILDER(Name("ThreadUnsafeUnigramCandidateSampler")
+ .Device(DEVICE_CPU),
+ SimpleCandidateSamplerOp<ThreadUnsafeUnigramSampler>);
+
+class AllCandidateSamplerOp : public BaseCandidateSamplerOp {
+ public:
+ explicit AllCandidateSamplerOp(OpKernelConstruction* context)
+ : BaseCandidateSamplerOp(context) {
+ int64 range_max;
+ OP_REQUIRES_OK(context, context->GetAttr("num_sampled", &range_max));
+ set_sampler(new AllSampler(range_max));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AllCandidateSampler").Device(DEVICE_CPU),
+ AllCandidateSamplerOp);
+
+class FixedUnigramCandidateSamplerOp : public BaseCandidateSamplerOp {
+ public:
+ explicit FixedUnigramCandidateSamplerOp(OpKernelConstruction* context)
+ : BaseCandidateSamplerOp(context) {
+ int64 range_max;
+ OP_REQUIRES_OK(context, context->GetAttr("range_max", &range_max));
+ string vocab_file;
+ OP_REQUIRES_OK(context, context->GetAttr("vocab_file", &vocab_file));
+ std::vector<float> unigrams;
+ OP_REQUIRES_OK(context, context->GetAttr("unigrams", &unigrams));
+ OP_REQUIRES(
+ context, !vocab_file.empty() || !unigrams.empty(),
+ errors::InvalidArgument("Must provide either vocab_file or unigrams."));
+ OP_REQUIRES(context, vocab_file.empty() || unigrams.empty(),
+ errors::InvalidArgument(
+ "Must only provide one of vocab_file and unigrams."));
+ float distortion;
+ OP_REQUIRES_OK(context, context->GetAttr("distortion", &distortion));
+ int64 num_reserved_ids;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("num_reserved_ids", &num_reserved_ids));
+ int64 num_shards;
+ OP_REQUIRES_OK(context, context->GetAttr("num_shards", &num_shards));
+ int64 shard;
+ OP_REQUIRES_OK(context, context->GetAttr("shard", &shard));
+
+ if (!vocab_file.empty()) {
+ set_sampler(new FixedUnigramSampler(context->env(), range_max, vocab_file,
+ distortion, num_reserved_ids,
+ num_shards, shard));
+ } else {
+ set_sampler(new FixedUnigramSampler(range_max, unigrams, distortion,
+ num_reserved_ids, num_shards, shard));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FixedUnigramCandidateSampler").Device(DEVICE_CPU),
+ FixedUnigramCandidateSamplerOp);
+
+class ComputeAccidentalHitsOp : public OpKernel {
+ public:
+ explicit ComputeAccidentalHitsOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_true", &num_true_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& in_true_candidates = context->input(0);
+ TensorShape in_true_candidates_shape = in_true_candidates.shape();
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) &&
+ in_true_candidates_shape.dim_size(1) == num_true_,
+ errors::InvalidArgument(
+ "true_candidates must be a batch_size * num_true matrix"));
+
+ const int64 batch_size = in_true_candidates_shape.dim_size(0);
+
+ const Tensor& in_sampled_candidates = context->input(1);
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsVector(in_sampled_candidates.shape()),
+ errors::InvalidArgument(
+ "sampled_candidates must be a vector, which is typically "
+ "an output from CandidateSampler"));
+
+ std::unordered_map<int64, int> sampled_candidate_to_pos;
+ for (int64 i = 0; i < in_sampled_candidates.dim_size(0); ++i) {
+ sampled_candidate_to_pos[in_sampled_candidates.vec<int64>()(i)] = i;
+ }
+
+ // Produce output in the same format as UnpackSparseFeatures.
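+    // indices holds the row in the batch, ids the position of the colliding
+    // sampled candidate, and weights -FLT_MAX so the accidental hit can be
+    // suppressed downstream (e.g. by adding the weight to the matching logit).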
+ std::vector<int> indices;
+ std::vector<int64> ids;
+ std::vector<float> weights;
+
+ for (int64 i = 0; i < batch_size; ++i) {
+ for (int64 j = 0; j < num_true_; ++j) {
+ const int64 true_candidate = in_true_candidates.matrix<int64>()(i, j);
+ const auto look = sampled_candidate_to_pos.find(true_candidate);
+ if (look != sampled_candidate_to_pos.end()) {
+ indices.push_back(i);
+ ids.push_back(look->second);
+ weights.push_back(-FLT_MAX);
+ }
+ }
+ }
+
+ Tensor* out_indices = nullptr;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_output(
+ 0, TensorShape({static_cast<int>(indices.size())}), &out_indices));
+ Tensor* out_ids = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(
+ 1, TensorShape({static_cast<int>(ids.size())}), &out_ids));
+ Tensor* out_weights = nullptr;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_output(
+ 2, TensorShape({static_cast<int>(weights.size())}), &out_weights));
+
+ for (size_t i = 0; i < indices.size(); ++i) {
+ out_indices->vec<int32>()(i) = indices[i];
+ out_ids->vec<int64>()(i) = ids[i];
+ out_weights->vec<float>()(i) = weights[i];
+ }
+ }
+
+ private:
+ int64 num_true_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ComputeAccidentalHits").Device(DEVICE_CPU),
+ ComputeAccidentalHitsOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
new file mode 100644
index 0000000000..779ac57b6a
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -0,0 +1,233 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/cast_op.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename Device, typename Tout, typename Tin>
+void CastMaybeInline(const Device& d, typename TTypes<Tout>::Flat o,
+ typename TTypes<Tin>::ConstFlat i) {
+ if (o.size() * (sizeof(Tin) + sizeof(Tout)) < 131072) {
+    // Small cast (< 128 KiB of combined input and output data) on a CPU:
+    // do it inline on the calling thread.
+ o = i.template cast<Tout>();
+ } else {
+ o.device(d) = i.template cast<Tout>();
+ }
+}
+
+template <typename O, typename I>
+struct CastFunctor<CPUDevice, O, I> {
+ void operator()(const CPUDevice& d, typename TTypes<O>::Flat o,
+ typename TTypes<I>::ConstFlat i) {
+ CastMaybeInline<CPUDevice, O, I>(d, o, i);
+ }
+};
+
+} // namespace functor
+
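+// CAST_CASE compares the runtime (SrcT, DstT) attribute pair against
+// (IN, OUT); on a match it binds work_ to a lambda that runs the
+// corresponding CastFunctor and returns OK from Prepare().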
+#define CAST_CASE(DEVICE, IN, OUT) \
+ if (DataTypeToEnum<IN>::value == src_dtype_ && \
+ DataTypeToEnum<OUT>::value == dst_dtype_) { \
+ work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \
+ functor::CastFunctor<DEVICE, OUT, IN> func; \
+ func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
+ }; \
+ return Status::OK(); \
+ }
+
+class CastOpBase : public OpKernel {
+ public:
+ explicit CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& inp = ctx->input(0);
+ if (work_ == nullptr) {
+ ctx->set_output(0, inp);
+ } else {
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
+ work_(ctx, inp, out);
+ }
+ }
+
+ protected:
+ DataType src_dtype_;
+ DataType dst_dtype_;
+ std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
+
+ virtual Status Prepare() = 0;
+ Status Unimplemented() {
+ return errors::Unimplemented("Cast ", DataTypeString(src_dtype_), " to ",
+ DataTypeString(dst_dtype_),
+ " is not supported");
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
+};
+
+class CpuCastOp : public CastOpBase {
+ public:
+ explicit CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
+ OP_REQUIRES_OK(ctx, Prepare());
+ }
+
+ protected:
+ Status Prepare() override {
+ if (src_dtype_ == dst_dtype_) {
+ work_ = nullptr; // Identity
+ return Status::OK();
+ }
+ CAST_CASE(CPUDevice, bool, float);
+ CAST_CASE(CPUDevice, bool, int32);
+ CAST_CASE(CPUDevice, bool, double);
+ CAST_CASE(CPUDevice, double, float);
+ CAST_CASE(CPUDevice, double, int32);
+ CAST_CASE(CPUDevice, double, int64);
+ CAST_CASE(CPUDevice, float, double);
+ CAST_CASE(CPUDevice, float, uint8);
+ CAST_CASE(CPUDevice, float, int32);
+ CAST_CASE(CPUDevice, float, int64);
+ CAST_CASE(CPUDevice, int32, double);
+ CAST_CASE(CPUDevice, int32, float);
+ CAST_CASE(CPUDevice, int32, uint8);
+ CAST_CASE(CPUDevice, int32, int64);
+ CAST_CASE(CPUDevice, int64, double);
+ CAST_CASE(CPUDevice, int64, float);
+ CAST_CASE(CPUDevice, int64, int32);
+ CAST_CASE(CPUDevice, uint8, float);
+ CAST_CASE(CPUDevice, uint8, int32);
+ CAST_CASE(CPUDevice, uint8, int64);
+ CAST_CASE(CPUDevice, uint8, double);
+ if (src_dtype_ == DT_BFLOAT16 && dst_dtype_ == DT_FLOAT) {
+ work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ int64 N = out->NumElements();
+ auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
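+          // Shard across at most 4 threads, and only when each thread would
+          // get at least 4096 elements; otherwise convert inline below.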
+ int num_threads =
+ std::min<int>(std::min(4, worker_threads->num_threads), N / 4096);
+ if (num_threads < 1) {
+ BFloat16ToFloat(inp.flat<bfloat16>().data(),
+ out->flat<float>().data(), N);
+ } else {
+ auto work = [&inp, &out](int64 start, int64 end) {
+ BFloat16ToFloat(inp.flat<bfloat16>().data() + start,
+ out->flat<float>().data() + start, end - start);
+ };
+ Shard(num_threads, worker_threads->workers, N, 100, work);
+ }
+ };
+ return Status::OK();
+ }
+ if (src_dtype_ == DT_FLOAT && dst_dtype_ == DT_BFLOAT16) {
+ work_ = [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ int64 N = out->NumElements();
+ auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+ int num_threads =
+ std::min<int>(std::min(4, worker_threads->num_threads), N / 4096);
+ if (num_threads < 1) {
+ FloatToBFloat16(inp.flat<float>().data(),
+ out->flat<bfloat16>().data(), N);
+ } else {
+ auto work = [&inp, &out](int64 start, int64 end) {
+ FloatToBFloat16(inp.flat<float>().data() + start,
+ out->flat<bfloat16>().data() + start, end - start);
+ };
+ Shard(num_threads, worker_threads->workers, N, 100, work);
+ }
+ };
+ return Status::OK();
+ }
+ return Unimplemented();
+ }
+};
+
+class GpuCastOp : public CastOpBase {
+ public:
+ explicit GpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
+ OP_REQUIRES_OK(ctx, Prepare());
+ }
+
+ protected:
+ Status Prepare() override {
+ if (src_dtype_ == dst_dtype_) {
+ work_ = nullptr; // Identity
+ return Status::OK();
+ }
+ CAST_CASE(GPUDevice, bfloat16, float);
+ CAST_CASE(GPUDevice, bool, float);
+ CAST_CASE(GPUDevice, double, float);
+ CAST_CASE(GPUDevice, double, int64);
+ CAST_CASE(GPUDevice, float, bfloat16);
+ CAST_CASE(GPUDevice, float, double);
+ CAST_CASE(GPUDevice, float, int64);
+ CAST_CASE(GPUDevice, int64, double);
+ CAST_CASE(GPUDevice, int64, float);
+ CAST_CASE(GPUDevice, uint8, float);
+ CAST_CASE(GPUDevice, float, uint8);
+ CAST_CASE(GPUDevice, bool, int32);
+ CAST_CASE(GPUDevice, double, int32);
+ CAST_CASE(GPUDevice, float, int32);
+ CAST_CASE(GPUDevice, int32, double);
+ CAST_CASE(GPUDevice, int32, float);
+ CAST_CASE(GPUDevice, int32, int64);
+ CAST_CASE(GPUDevice, int64, int32);
+ return Unimplemented();
+ }
+};
+
+#undef CAST_CASE
+
+REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
+
+#if GOOGLE_CUDA
+#define REGISTER_CAST_GPU(srctype, dsttype) \
+ REGISTER_KERNEL_BUILDER(Name("Cast") \
+ .TypeConstraint<srctype>("SrcT") \
+ .TypeConstraint<dsttype>("DstT") \
+ .Device(DEVICE_GPU), \
+ GpuCastOp);
+REGISTER_CAST_GPU(bfloat16, float);
+REGISTER_CAST_GPU(bool, float);
+REGISTER_CAST_GPU(double, float);
+REGISTER_CAST_GPU(double, int64);
+REGISTER_CAST_GPU(float, bfloat16);
+REGISTER_CAST_GPU(float, double);
+REGISTER_CAST_GPU(float, int64);
+REGISTER_CAST_GPU(int64, double);
+REGISTER_CAST_GPU(int64, float);
+REGISTER_CAST_GPU(uint8, float);
+REGISTER_CAST_GPU(float, uint8);
+REGISTER_CAST_GPU(bool, int32);
+REGISTER_CAST_GPU(double, int32);
+REGISTER_CAST_GPU(float, int32);
+REGISTER_CAST_GPU(int32, double);
+REGISTER_CAST_GPU(int32, float);
+REGISTER_CAST_GPU(int32, int64);
+REGISTER_CAST_GPU(int64, int32);
+#undef REGISTER_CAST_GPU
+#endif // GOOGLE_CUDA
+
+// HostCast differs from Cast in that its input and output are in host memory.
+REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
+ CpuCastOp);
+
+} // end namespace tensorflow
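
The bfloat16 paths above split the flat buffer into contiguous [start, end) ranges and hand each range to Shard(). Below is a minimal standalone sketch of that work-splitting idea using plain std::thread rather than TensorFlow's Shard() and thread pool; ConvertRange and ParallelConvert are hypothetical names used only for illustration.

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical element-wise conversion, standing in for BFloat16ToFloat etc.
static void ConvertRange(const float* in, double* out, int64_t start, int64_t end) {
  for (int64_t i = start; i < end; ++i) out[i] = static_cast<double>(in[i]);
}

// Split n elements into contiguous blocks and convert each on its own thread.
static void ParallelConvert(const float* in, double* out, int64_t n, int num_threads) {
  if (num_threads < 1) num_threads = 1;
  const int64_t block = (n + num_threads - 1) / num_threads;
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) {
    const int64_t start = t * block;
    const int64_t end = std::min(n, start + block);
    if (start >= end) break;
    workers.emplace_back(ConvertRange, in, out, start, end);
  }
  for (auto& w : workers) w.join();
}
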
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
new file mode 100644
index 0000000000..d066206abc
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op.h
@@ -0,0 +1,71 @@
+#ifndef TENSORFLOW_KERNELS_CAST_OP_H_
+#define TENSORFLOW_KERNELS_CAST_OP_H_
+
+#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/port.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename Tout, typename Tin>
+void Cast(const Device& d, typename TTypes<Tout>::Flat o,
+ typename TTypes<Tin>::ConstFlat i) {
+ o.device(d) = i.template cast<Tout>();
+}
+
+template <typename Device, typename Tout, typename Tin>
+struct CastFunctor {
+ void operator()(const Device& d, typename TTypes<Tout>::Flat o,
+ typename TTypes<Tin>::ConstFlat i);
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+namespace Eigen {
+namespace internal {
+
+// Specialized cast op impls for bfloat16.
+template <>
+struct scalar_cast_op< ::tensorflow::bfloat16, float> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef float result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
+ const ::tensorflow::bfloat16& a) const {
+ static_assert(::tensorflow::port::kLittleEndian, "");
+ float ret;
+ uint16_t* p = reinterpret_cast<uint16_t*>(&ret);
+ p[0] = 0;
+ p[1] = a.value;
+ return ret;
+ }
+};
+
+template <>
+struct functor_traits<scalar_cast_op< ::tensorflow::bfloat16, float> > {
+ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
+};
+
+template <>
+struct scalar_cast_op<float, ::tensorflow::bfloat16> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef ::tensorflow::bfloat16 result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const ::tensorflow::bfloat16 operator()(
+ const float a) const {
+ static_assert(::tensorflow::port::kLittleEndian, "");
+ const uint16_t* p = reinterpret_cast<const uint16_t*>(&a);
+ return ::tensorflow::bfloat16(p[1]);
+ }
+};
+
+template <>
+struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16> > {
+ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false };
+};
+
+} // namespace internal
+} // namespace Eigen
+
+#endif // TENSORFLOW_KERNELS_CAST_OP_H_
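
The Eigen specializations in this header rely on the fact that, on a little-endian machine, a bfloat16 is simply the upper 16 bits of an IEEE float32 (sign, 8 exponent bits, 7 mantissa bits). A small self-contained sketch of the same bit trick, outside of Eigen:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep the high half of the float's bit pattern; drop the low mantissa bits.
static uint16_t FloatToBF16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Re-expand by placing the 16 stored bits in the high half; low bits become 0.
static float BF16BitsToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  float x = 3.14159f;
  std::printf("%f -> %f\n", x, BF16BitsToFloat(FloatToBF16Bits(x)));  // lossy round trip
}
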
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
new file mode 100644
index 0000000000..cd198c752b
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -0,0 +1,45 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/kernels/cast_op.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename O, typename I>
+struct CastFunctor<GPUDevice, O, I> {
+ void operator()(const GPUDevice& d, typename TTypes<O>::Flat o,
+ typename TTypes<I>::ConstFlat i) {
+ Cast<GPUDevice, O, I>(d, o, i);
+ }
+};
+
+#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>;
+DEFINE(float, double);
+DEFINE(float, int32);
+DEFINE(float, int64);
+DEFINE(double, float);
+DEFINE(double, int32);
+DEFINE(double, int64);
+DEFINE(int32, float);
+DEFINE(int32, double);
+DEFINE(int32, int64);
+DEFINE(int64, float);
+DEFINE(int64, double);
+DEFINE(int64, int32);
+DEFINE(int32, bool);
+DEFINE(float, bool);
+DEFINE(float, uint8);
+DEFINE(uint8, float);
+DEFINE(float, bfloat16);
+DEFINE(bfloat16, float);
+#undef DEFINE
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
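
cast_op.h only declares CastFunctor's operator(); the DEFINE() lines above explicitly instantiate the GPU definitions so that nvcc emits device code for each (O, I) pair in this translation unit and the op merely links against them. A schematic single-file sketch of that declare-then-explicitly-instantiate pattern (CastLike is a made-up name, not TensorFlow code):

// --- what the header does: declaration only ---
template <typename O, typename I>
struct CastLike {
  void operator()(const I* in, O* out, int n);
};

// --- what the .cu/.cc file does: definition plus explicit instantiations ---
template <typename O, typename I>
void CastLike<O, I>::operator()(const I* in, O* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = static_cast<O>(in[i]);
}

template struct CastLike<float, int>;    // analogous to DEFINE(float, int32)
template struct CastLike<double, float>; // analogous to DEFINE(double, float)
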
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
new file mode 100644
index 0000000000..f774fbcfe8
--- /dev/null
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -0,0 +1,100 @@
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+template <typename Src, typename Dst>
+static Graph* Cast(int num) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor data(DataTypeToEnum<Src>::value,
+ TensorShape({64, 64, num / (64 * 64)}));
+ data.flat<Src>().setRandom();
+ test::graph::Cast(g, test::graph::Constant(g, data),
+ DataTypeToEnum<Dst>::value);
+ return g;
+}
+
+class CastOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType src, DataType dst) {
+ RequireDefaultOps();
+ EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(DT_INT32))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(CastOpTest, Int32ToUint8) {
+ MakeOp(DT_INT32, DT_UINT8);
+ AddInputFromArray<int32>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_UINT8, TensorShape({1, 2, 2, 1}));
+ test::FillValues<uint8>(&expected, {1, 2, 3, 4});
+ test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
+}
+
+static void BM_cpu_float_int64(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num *
+ (sizeof(float) + sizeof(int64)));
+ testing::UseRealTime();
+ test::Benchmark("cpu", Cast<float, int64>(num)).Run(iters);
+}
+BENCHMARK(BM_cpu_float_int64)->Arg(64 << 10)->Arg(32 << 20);
+
+static void BM_gpu_float_int64(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num *
+ (sizeof(float) + sizeof(int64)));
+ testing::UseRealTime();
+ test::Benchmark("gpu", Cast<float, int64>(num)).Run(iters);
+}
+BENCHMARK(BM_gpu_float_int64)->Arg(64 << 10)->Arg(32 << 20);
+
+static void BM_cpu_bool_float(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num *
+ (sizeof(bool) + sizeof(float)));
+ testing::UseRealTime();
+ test::Benchmark("cpu", Cast<bool, float>(num)).Run(iters);
+}
+BENCHMARK(BM_cpu_bool_float)->Arg(64 << 10)->Arg(32 << 20);
+
+static void BM_gpu_bool_float(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num *
+ (sizeof(bool) + sizeof(float)));
+ testing::UseRealTime();
+ test::Benchmark("gpu", Cast<bool, float>(num)).Run(iters);
+}
+BENCHMARK(BM_gpu_bool_float)->Arg(64 << 10)->Arg(32 << 20);
+
+static void BM_cpu_float_bfloat16(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num *
+ (sizeof(float) + sizeof(bfloat16)));
+ testing::UseRealTime();
+ test::Benchmark("cpu", Cast<float, bfloat16>(num)).Run(iters);
+}
+BENCHMARK(BM_cpu_float_bfloat16)->Arg(64 << 10)->Arg(32 << 20);
+
+static void BM_cpu_bfloat16_float(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num *
+ (sizeof(float) + sizeof(bfloat16)));
+ testing::UseRealTime();
+ test::Benchmark("cpu", Cast<bfloat16, float>(num)).Run(iters);
+}
+BENCHMARK(BM_cpu_bfloat16_float)->Arg(64 << 10)->Arg(32 << 20);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc
new file mode 100644
index 0000000000..65487a303c
--- /dev/null
+++ b/tensorflow/core/kernels/check_numerics_op.cc
@@ -0,0 +1,190 @@
+// See docs in ../ops/array_ops.cc.
+
+#include <math.h>
+#include <algorithm>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#if GOOGLE_CUDA
+template <typename T>
+struct CheckNumericsLaunch {
+ void Run(const GPUDevice& d, const T* data, int size,
+ int abnormal_detected[2]);
+};
+#endif
+
+namespace {
+
+template <typename Device, typename T>
+class CheckNumericsOp;
+
+// Partial specialization for CPU
+template <typename T>
+class CheckNumericsOp<CPUDevice, T> : public OpKernel {
+ public:
+ explicit CheckNumericsOp(OpKernelConstruction* context) : OpKernel(context) {
+ // message_ is used as the prefix for the assertion error message. For
+ // instance, this can be the name of the input op that produced the tensor.
+ OP_REQUIRES_OK(context, context->GetAttr("message", &message_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // pass along the input to the output
+ context->set_output(0, context->input(0));
+
+ auto in = context->input(0).flat<T>();
+ const T* data = in.data();
+ const int size = in.size();
+ // Check to see if any element of the tensor is NaN or Inf.
+ int fp_props =
+ std::accumulate(data, data + size, 0, [](const int& x, const T& y) {
+ int prop = std::fpclassify(y);
+ int result = x;
+ if (prop == FP_INFINITE) {
+ result |= kInfBit;
+ } else if (prop == FP_NAN) {
+ result |= kNaNBit;
+ }
+ return result;
+ });
+ string status;
+ if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
+ status = "Inf and NaN";
+ } else {
+ if (fp_props & kInfBit) {
+ status = "Inf";
+ }
+ if (fp_props & kNaNBit) {
+ status = "NaN";
+ }
+ }
+ if (!status.empty()) {
+ context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
+ status, " values"));
+ }
+ }
+
+ private:
+ string message_;
+ static const int kInfBit = 0x01;
+ static const int kNaNBit = 0x02;
+};
+
+#if GOOGLE_CUDA
+// Partial specialization for GPU
+template <typename T>
+class CheckNumericsOp<GPUDevice, T> : public OpKernel {
+ public:
+ typedef GPUDevice Device;
+
+ explicit CheckNumericsOp(OpKernelConstruction* context) : OpKernel(context) {
+ // message_ is used as the prefix for the assertion error message. For
+ // instance, this can be the name of the input op that produced the tensor.
+ OP_REQUIRES_OK(context, context->GetAttr("message", &message_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // pass along the input to the output
+ context->set_output(0, context->input(0));
+ auto input = context->input(0).flat<T>();
+
+ // Allocate and initialize the elements to hold the check results
+ const int abnormal_detected_size = 2;
+ Tensor abnormal_detected;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DT_INT32, TensorShape({abnormal_detected_size}),
+ &abnormal_detected));
+
+ auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ perftools::gputools::DeviceMemoryBase abnormal_detected_ptr(
+ abnormal_detected.flat<int>().data(),
+ abnormal_detected.flat<int>().size());
+ stream->ThenMemset32(&abnormal_detected_ptr, 0,
+ abnormal_detected.flat<int>().size() * sizeof(int));
+
+ // Call the Cuda kernels for the numerical checks
+ const Device& d = context->eigen_device<Device>();
+ CheckNumericsLaunch<T>().Run(d, input.data(), input.size(),
+ abnormal_detected.flat<int>().data());
+
+ // Copy the results from device to host
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ attr.set_gpu_compatible(true);
+ Tensor abnormal_detected_out;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DT_INT32, TensorShape({abnormal_detected_size}),
+ &abnormal_detected_out, attr));
+ int* abnormal_detected_host = abnormal_detected_out.flat<int>().data();
+ stream->ThenMemcpy(abnormal_detected_host, abnormal_detected_ptr,
+ abnormal_detected_size * sizeof(int));
+ stream->BlockHostUntilDone();
+ OP_REQUIRES(context, stream->ok(),
+ errors::Internal("cudaMemcpy from device to host failed"));
+
+ int is_nan = abnormal_detected_host[0];
+ int is_inf = abnormal_detected_host[1];
+ if (is_nan || is_inf) {
+ string status;
+ LOG(ERROR) << "abnormal_detected_host @" << abnormal_detected_host
+ << " = {" << is_nan << ", " << is_inf << "} " << message_;
+
+ // Results should always be 0 or 1. If we see anything else, then
+ // there has been some GPU memory corruption.
+ CHECK_GE(is_nan, 0);
+ CHECK_GE(is_inf, 0);
+ CHECK_LE(is_nan, 1);
+ CHECK_LE(is_inf, 1);
+
+ if (is_nan && is_inf) {
+ status = "Inf and NaN";
+ } else if (is_nan) {
+ status = "NaN";
+ } else if (is_inf) {
+ status = "Inf";
+ }
+ context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
+ status, " values"));
+ }
+ }
+
+ private:
+ string message_;
+};
+#endif // GOOGLE_CUDA
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("CheckNumerics")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ CheckNumericsOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("CheckNumerics")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ CheckNumericsOp<CPUDevice, double>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("CheckNumerics")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ CheckNumericsOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("CheckNumerics")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T"),
+ CheckNumericsOp<GPUDevice, double>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
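
The CPU kernel above folds std::fpclassify() results into a bitmask with std::accumulate and then decodes the combined flags once. A self-contained sketch of that scan on a plain std::vector:

#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int kInfBit = 0x01, kNaNBit = 0x02;
  std::vector<float> data = {1.0f, INFINITY, -2.5f, NAN};
  // Accumulate a bitmask over the buffer: one pass, no early exit.
  int props = std::accumulate(data.begin(), data.end(), 0, [&](int acc, float v) {
    switch (std::fpclassify(v)) {
      case FP_INFINITE: return acc | kInfBit;
      case FP_NAN:      return acc | kNaNBit;
      default:          return acc;
    }
  });
  if (props & kInfBit) std::printf("saw Inf\n");
  if (props & kNaNBit) std::printf("saw NaN\n");
}
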
diff --git a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
new file mode 100644
index 0000000000..cb84f98731
--- /dev/null
+++ b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
@@ -0,0 +1,62 @@
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <assert.h>
+
+#include <math.h>
+#include <algorithm>
+
+#include "tensorflow/core/platform/port.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// A CUDA kernel to check whether each element is Inf or NaN. If any such
+// element exists, the corresponding entry in abnormal_detected is set to 1.
+template <typename T>
+__global__ void CheckNumericsKernel(const T *data, int size,
+ int abnormal_detected[2]) {
+ const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+ const int32 total_thread_count = gridDim.x * blockDim.x;
+
+ int32 offset = thread_id;
+
+ while (offset < size) {
+ if (isnan(data[offset])) {
+ abnormal_detected[0] = 1;
+ }
+ if (isinf(data[offset])) {
+ abnormal_detected[1] = 1;
+ }
+ offset += total_thread_count;
+ }
+}
+
+} // namespace
+
+// A simple launch pad for the CUDA kernel that checks for numerical
+// abnormalities in the given array.
+template <typename T>
+struct CheckNumericsLaunch {
+ void Run(const GPUDevice &d, const T *data, int size,
+ int abnormal_detected[2]) {
+ const int32 block_size = d.maxCudaThreadsPerBlock();
+ const int32 num_blocks =
+ (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
+ block_size;
+
+ CheckNumericsKernel<T><<<num_blocks, block_size, 0, d.stream()>>>(
+ data, size, abnormal_detected);
+ }
+};
+
+template struct CheckNumericsLaunch<float>;
+template struct CheckNumericsLaunch<double>;
+
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
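
CheckNumericsKernel above uses a grid-stride loop: each thread starts at its global index and advances by the total number of launched threads, so any grid size covers any input size, and the concurrent writes are benign because every writer stores the same value. A minimal CUDA sketch of the same pattern (CountNonFinite is a hypothetical kernel, compiled with nvcc):

__global__ void CountNonFinite(const float* data, int size, int* flag) {
  const int stride = gridDim.x * blockDim.x;
  // Grid-stride loop: thread t visits t, t + stride, t + 2*stride, ...
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += stride) {
    if (!isfinite(data[i])) *flag = 1;  // benign race: every writer stores 1
  }
}
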
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
new file mode 100644
index 0000000000..12632fb248
--- /dev/null
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -0,0 +1,71 @@
+// See docs in ../ops/linalg_ops.cc.
+// TODO(konstantinos): Enable complex inputs. This will require additional tests
+// and OP_REQUIRES.
+
+#include <cmath>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/Eigen/Cholesky"
+
+namespace tensorflow {
+
+template <class Scalar, bool SupportsBatchOperationT>
+class CholeskyOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
+ public:
+ explicit CholeskyOp(OpKernelConstruction* context)
+ : LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
+
+ TensorShape GetOutputMatrixShape(
+ const TensorShape& input_matrix_shape) override {
+ return input_matrix_shape;
+ }
+
+ int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override {
+ const int64 rows = input_matrix_shape.dim_size(0);
+ if (rows > (1LL << 20)) {
+ // A big number to cap the cost in case of overflow.
+ return kint32max;
+ } else {
+ return rows * rows * rows;
+ }
+ }
+
+ using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
+ using
+ typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
+ MatrixMap* output) override {
+ OP_REQUIRES(context, input.rows() == input.cols(),
+ errors::InvalidArgument("Input matrix must be square."));
+ if (input.rows() == 0) {
+ // If X is an empty matrix (0 rows, 0 columns), then its Cholesky factor
+ // is also the empty matrix, so we simply return X.
+ return;
+ }
+ // Perform the actual LL^T Cholesky decomposition. This will only use
+ // the lower triangular part of data_in by default. The upper triangular
+ // part of the matrix will not be read.
+ Eigen::LLT<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>> llt_decomposition(input);
+
+ // Output the lower triangular in a dense form.
+ *output = llt_decomposition.matrixL();
+
+ OP_REQUIRES(context, llt_decomposition.info() == Eigen::Success,
+ errors::InvalidArgument("LLT decomposition was not successful. "
+ "The input might not be valid."));
+ }
+};
+
+REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float, false>), float);
+REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double, false>), double);
+REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float, true>), float);
+REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double, true>), double);
+} // namespace tensorflow
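
ComputeMatrix above defers to Eigen::LLT and then densifies matrixL(). A standalone sketch of that factorization on a small symmetric positive definite matrix, assuming Eigen is on the include path:

#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::Matrix3d a;
  a << 4, 2, 1,
       2, 5, 3,
       1, 3, 6;                        // symmetric positive definite
  Eigen::LLT<Eigen::Matrix3d> llt(a);  // reads only the lower triangle
  if (llt.info() != Eigen::Success) {
    std::cerr << "input is not positive definite\n";
    return 1;
  }
  Eigen::Matrix3d l = llt.matrixL();   // dense lower-triangular factor
  std::cout << l << "\n";
  std::cout << (l * l.transpose() - a).norm() << "\n";  // ~0
}
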
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
new file mode 100644
index 0000000000..b68fcec515
--- /dev/null
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -0,0 +1,153 @@
+// See docs in ../ops/array_ops.cc.
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/concat_op.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// --------------------------------------------------------------------------
+template <typename Device, typename T>
+class ConcatOp : public OpKernel {
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+
+ explicit ConcatOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor* concat_dim_tensor;
+ OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor));
+ OP_REQUIRES(
+ c, TensorShapeUtils::IsLegacyScalar(concat_dim_tensor->shape()),
+ errors::InvalidArgument(
+ "Concat dim tensor should be a scalar integer, but got shape ",
+ concat_dim_tensor->shape().DebugString()));
+ const int32 concat_dim = concat_dim_tensor->scalar<int32>()();
+ OpInputList values;
+ OP_REQUIRES_OK(c, c->input_list("values", &values));
+ const int N = values.size();
+ const int input_dims = values[0].dims();
+ const TensorShape& input_shape = values[0].shape();
+ OP_REQUIRES(
+ c, (0 <= concat_dim && concat_dim < input_dims) ||
+ (kAllowLegacyScalars && concat_dim == 0),
+ errors::InvalidArgument(
+ "ConcatOp : Expected concatenating dimensions in the range [", 0,
+ ", ", input_dims, "), but got ", concat_dim));
+
+ // Note that we reduce the concat of n-dimensional tensors into a two
+ // dimensional concat. Assuming the dimensions of any input/output
+ // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
+ // the dimension indicated with size y0, we flatten it to {x, y}, where y =
+ // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
+ ConstMatrixVector inputs_flat;
+ inputs_flat.reserve(N);
+ int64 inputs_flat_dim0 = 1;
+ for (int d = 0; d < concat_dim; ++d) {
+ inputs_flat_dim0 *= input_shape.dim_size(d);
+ }
+ int output_concat_dim = 0;
+ const bool input_is_scalar = TensorShapeUtils::IsLegacyScalar(input_shape);
+ for (int i = 0; i < N; ++i) {
+ const auto in = values[i];
+ const bool in_is_scalar = TensorShapeUtils::IsLegacyScalar(in.shape());
+ OP_REQUIRES(
+ c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
+ errors::InvalidArgument(
+ "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
+ input_shape.ShortDebugString(), " vs. shape[", i, "] = ",
+ in.shape().ShortDebugString()));
+ for (int j = 0; j < input_dims; ++j) {
+ if (j == concat_dim) {
+ continue;
+ }
+ OP_REQUIRES(
+ c, in.dim_size(j) == input_shape.dim_size(j),
+ errors::InvalidArgument(
+ "ConcatOp : Dimensions of inputs should match: shape[0] = ",
+ input_shape.ShortDebugString(), " vs. shape[", i, "] = ",
+ in.shape().ShortDebugString()));
+ }
+ if (in.NumElements() > 0) {
+ int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+ }
+ // TODO(irving): Remove check once !kAllowLegacyScalars
+ output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1;
+ }
+
+ TensorShape output_shape(input_shape);
+ // TODO(irving): Remove rank 0 case once !kAllowLegacyScalars
+ if (output_shape.dims() == 0) {
+ output_shape.AddDim(output_concat_dim);
+ } else {
+ output_shape.set_dim(concat_dim, output_concat_dim);
+ }
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
+ if (output->NumElements() > 0) {
+ int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
+ auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
+ if (std::is_same<Device, GPUDevice>::value) {
+ ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ } else {
+ ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+ }
+ }
+ }
+};
+
+#define REGISTER_CONCAT(type) \
+ REGISTER_KERNEL_BUILDER(Name("Concat") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("concat_dim"), \
+ ConcatOp<CPUDevice, type>)
+
+TF_CALL_ALL_TYPES(REGISTER_CONCAT);
+REGISTER_CONCAT(quint8);
+REGISTER_CONCAT(qint8);
+REGISTER_CONCAT(qint32);
+REGISTER_CONCAT(bfloat16);
+
+#undef REGISTER_CONCAT
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("Concat") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("concat_dim"), \
+ ConcatOp<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Concat")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .HostMemory("concat_dim")
+ .HostMemory("values")
+ .HostMemory("output"),
+ ConcatOp<CPUDevice, int32>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
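
The comment in Compute() above reduces an n-dimensional concat to a 2-D concat of {x, y} matrices, with x the product of the dimensions before concat_dim and y the product of the remaining ones. A tiny standalone sketch of that flattening arithmetic:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape = {2, 3, 4, 5};
  const int concat_dim = 2;
  // x = product of dims before concat_dim, y = product of the rest.
  int64_t x = std::accumulate(shape.begin(), shape.begin() + concat_dim,
                              int64_t{1}, std::multiplies<int64_t>());
  int64_t y = std::accumulate(shape.begin() + concat_dim, shape.end(),
                              int64_t{1}, std::multiplies<int64_t>());
  std::printf("flattened to %lld x %lld\n", (long long)x, (long long)y);  // 6 x 20
}
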
diff --git a/tensorflow/core/kernels/concat_op.h b/tensorflow/core/kernels/concat_op.h
new file mode 100644
index 0000000000..664e55080d
--- /dev/null
+++ b/tensorflow/core/kernels/concat_op.h
@@ -0,0 +1,27 @@
+#ifndef TENSORFLOW_KERNELS_CONCAT_OP_H_
+#define TENSORFLOW_KERNELS_CONCAT_OP_H_
+
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/device_base.h"
+
+namespace tensorflow {
+
+// Assumes all inputs are nonempty
+template <typename T>
+void ConcatCPU(DeviceBase* d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output);
+
+// Assumes all inputs are nonempty
+template <typename T>
+void ConcatGPU(const Eigen::GpuDevice& d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CONCAT_OP_H_
diff --git a/tensorflow/core/kernels/concat_op_cpu.cc b/tensorflow/core/kernels/concat_op_cpu.cc
new file mode 100644
index 0000000000..679a53721c
--- /dev/null
+++ b/tensorflow/core/kernels/concat_op_cpu.cc
@@ -0,0 +1,122 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/concat_op.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+template <typename T>
+static inline void Copy(T* dst, const T* src, int n) {
+ if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ memcpy(dst, src, n * sizeof(T));
+ } else {
+ for (int k = 0; k < n; ++k) {
+ *dst++ = *src++;
+ }
+ }
+}
+
+template <typename T>
+void ConcatCPU(DeviceBase* d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output) {
+ int num_inputs = inputs.size();
+ std::vector<ptrdiff_t> sizes;
+ sizes.reserve(num_inputs);
+ int row_size = 0;
+ for (int j = 0; j < num_inputs; ++j) {
+ sizes.push_back(inputs[j]->dimension(1));
+ row_size += sizes.back();
+ }
+
+ auto worker_threads = d->tensorflow_cpu_worker_threads();
+ int num_threads = std::min<int>(std::min(4, worker_threads->num_threads),
+ output->size() / 4096);
+ // Single threaded mode.
+ if (num_threads == 0) {
+ T* out = &(*output)(0, 0);
+ std::vector<const T*> inp;
+ inp.reserve(num_inputs);
+ for (int j = 0; j < num_inputs; ++j) {
+ inp.push_back(&(*inputs[j])(0, 0));
+ }
+ const int dim0 = output->dimension(0);
+ for (int i = 0; i < dim0; ++i) {
+ for (int j = 0; j < num_inputs; ++j) {
+ auto size = sizes[j];
+ Copy(out, inp[j], size);
+ out += size;
+ inp[j] += size;
+ }
+ }
+ return;
+ }
+
+ // Sharded mode.
+ auto work = [&row_size, &sizes, &inputs, &output, &num_inputs](int64 start,
+ int64 end) {
+ int64 skipped_rows = start / row_size;
+ T* out = output->data() + skipped_rows * row_size;
+ T* out_start = output->data() + start;
+ T* out_end = output->data() + end;
+
+ // Handle partial row at start
+ if (out < out_start) {
+ for (int j = 0; j < num_inputs; ++j) {
+ ptrdiff_t size = sizes[j];
+ ptrdiff_t offset = out_start - out;
+ if (size <= offset) {
+ out += size;
+ continue;
+ }
+ const T* inp = &(*inputs[j])(skipped_rows, 0);
+ if (offset > 0) {
+ out += offset;
+ inp += offset;
+ size -= offset;
+ }
+ size = std::min(size, out_end - out);
+ if (size <= 0) break;
+ Copy(out, inp, size);
+ out += size;
+ }
+ ++skipped_rows;
+ }
+ if (out == out_end) return;
+ CHECK(out >= out_start);
+ CHECK(out < out_end);
+
+ // Copy remaining data.
+ std::vector<const T*> inp;
+ inp.reserve(num_inputs);
+ for (int j = 0; j < num_inputs; ++j) {
+ inp.push_back(&(*inputs[j])(skipped_rows, 0));
+ }
+ const int dim0 = output->dimension(0);
+ for (int i = skipped_rows; i < dim0; ++i) {
+ for (int j = 0; j < num_inputs; ++j) {
+ ptrdiff_t size = std::min(sizes[j], out_end - out);
+ Copy(out, inp[j], size);
+ out += size;
+ inp[j] += size;
+ if (out == out_end) return;
+ }
+ }
+ };
+ Shard(num_threads, worker_threads->workers, output->size(), 100, work);
+}
+
+#define REGISTER(T) \
+ template void ConcatCPU<T>( \
+ DeviceBase*, \
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>&, \
+ typename TTypes<T, 2>::Matrix* output);
+TF_CALL_ALL_TYPES(REGISTER)
+REGISTER(quint8)
+REGISTER(qint8)
+REGISTER(qint32)
+REGISTER(bfloat16)
+
+} // namespace tensorflow
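
The single-threaded path above walks the output row by row, appending one row slice from each input in turn. A standalone sketch of that interleaved row copy for row-major float matrices (ConcatRows is a hypothetical helper, not part of the patch):

#include <cstddef>
#include <cstring>
#include <vector>

// Each input is a rows x widths[j] row-major matrix; the output is
// rows x sum(widths), with inputs laid side by side within each row.
static void ConcatRows(const std::vector<const float*>& inputs,
                       const std::vector<int>& widths, int rows, float* out) {
  for (int i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < inputs.size(); ++j) {
      std::memcpy(out, inputs[j] + static_cast<std::ptrdiff_t>(i) * widths[j],
                  widths[j] * sizeof(float));
      out += widths[j];
    }
  }
}
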
diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc
new file mode 100644
index 0000000000..d8ce6bd85d
--- /dev/null
+++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -0,0 +1,41 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include <memory>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename T>
+void ConcatGPU(const GPUDevice& d,
+ const std::vector<
+ std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
+ typename TTypes<T, 2>::Matrix* output) {
+ Eigen::array<ptrdiff_t, 2> offset(0, 0);
+ for (int i = 0; i < inputs.size(); ++i) {
+ Eigen::array<ptrdiff_t, 2> size = inputs[i]->dimensions();
+ output->slice(offset, size).device(d) = *inputs[i];
+ offset[1] += size[1];
+ }
+}
+
+#define REGISTER_GPU(T) \
+ template void ConcatGPU<T>( \
+ const GPUDevice& d, \
+ const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \
+ inputs, \
+ typename TTypes<T, 2>::Matrix* output);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
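
ConcatGPU above writes each input into the output with an Eigen slice assignment at a running column offset. The same expression can be evaluated on the host with Eigen's Tensor module, which the following sketch does (assuming the unsupported Tensor headers are on the include path):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>
#include <vector>

int main() {
  Eigen::Tensor<float, 2> a(2, 3), b(2, 4);
  a.setConstant(1.0f);
  b.setConstant(2.0f);
  Eigen::Tensor<float, 2> out(2, 7);
  Eigen::array<Eigen::Index, 2> offset = {0, 0};
  for (const auto* t : std::vector<const Eigen::Tensor<float, 2>*>{&a, &b}) {
    Eigen::array<Eigen::Index, 2> size = {t->dimension(0), t->dimension(1)};
    out.slice(offset, size) = *t;  // same slice assignment ConcatGPU runs on the device
    offset[1] += size[1];
  }
  std::cout << out << "\n";  // columns 0-2 hold 1, columns 3-6 hold 2
}
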
diff --git a/tensorflow/core/kernels/concat_op_test.cc b/tensorflow/core/kernels/concat_op_test.cc
new file mode 100644
index 0000000000..4ccc5b5b19
--- /dev/null
+++ b/tensorflow/core/kernels/concat_op_test.cc
@@ -0,0 +1,240 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+// For the benchmark, we set up two 2-dimensional tensors, each kDim1 x 'dim2'
+// in size, and concatenate them along "concat_dimension".
+template <typename T>
+static void ConcatHelper(int iters, int concat_dimension, int dim2) {
+ testing::StopTiming();
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+
+ DataType dt = DataTypeToEnum<T>::v();
+ const int kDim1 = 100;
+ Tensor concat_dim(DT_INT32, TensorShape({}));
+ concat_dim.scalar<int32>()() = concat_dimension;
+ Tensor in0(dt, TensorShape({kDim1, dim2}));
+ in0.flat<T>().setRandom();
+ Tensor in1(dt, TensorShape({kDim1, dim2}));
+ in1.flat<T>().setRandom();
+
+ Node* node;
+ TF_CHECK_OK(
+ NodeBuilder(g->NewName("n"), "Concat")
+ .Input(test::graph::Constant(g, concat_dim))
+ .Input({test::graph::Constant(g, in0), test::graph::Constant(g, in1)})
+ .Attr("N", 2)
+ .Attr("T", dt)
+ .Finalize(g, &node));
+
+ testing::BytesProcessed(static_cast<int64>(iters) *
+ ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T));
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+ testing::UseRealTime();
+}
+
+static void BM_ConcatDim0Float(int iters, int dim2) {
+ ConcatHelper<float>(iters, 0, dim2);
+}
+
+static void BM_ConcatDim1Float(int iters, int dim2) {
+ ConcatHelper<float>(iters, 1, dim2);
+}
+
+BENCHMARK(BM_ConcatDim0Float)->Arg(1000)->Arg(100000)->Arg(1000000);
+BENCHMARK(BM_ConcatDim1Float)->Arg(1000)->Arg(100000)->Arg(1000000);
+
+static void BM_ConcatDim1int16(int iters, int dim2) {
+ ConcatHelper<int16>(iters, 1, dim2);
+}
+static void BM_ConcatDim1bfloat16(int iters, int dim2) {
+ ConcatHelper<bfloat16>(iters, 1, dim2);
+}
+
+BENCHMARK(BM_ConcatDim1int16)->Arg(1000)->Arg(100000)->Arg(1000000);
+BENCHMARK(BM_ConcatDim1bfloat16)->Arg(1000)->Arg(100000)->Arg(1000000);
+
+template <typename T>
+static void ConcatManyHelper(int iters, int concat_dimension, int dim2) {
+ testing::StopTiming();
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+
+ DataType dt = DataTypeToEnum<T>::v();
+ const int kDim1 = 40000;
+ const int kNumInputs = 64;
+ Tensor concat_dim(DT_INT32, TensorShape({}));
+ concat_dim.scalar<int32>()() = concat_dimension;
+ std::vector<NodeBuilder::NodeOut> inputs;
+ inputs.reserve(kNumInputs);
+ for (int i = 0; i < kNumInputs; ++i) {
+ Tensor in(dt, TensorShape({kDim1, dim2}));
+ in.flat<T>().setRandom();
+ inputs.push_back(test::graph::Constant(g, in));
+ }
+
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Concat")
+ .Input(test::graph::Constant(g, concat_dim))
+ .Input(inputs)
+ .Attr("N", 64)
+ .Attr("T", dt)
+ .Finalize(g, &node));
+ testing::BytesProcessed(static_cast<int64>(iters) * kDim1 * dim2 *
+ kNumInputs * sizeof(T));
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+ testing::UseRealTime();
+}
+
+static void BM_ConcatManyDim1bfloat16(int iters, int dim2) {
+ ConcatManyHelper<bfloat16>(iters, 1, dim2);
+}
+
+BENCHMARK(BM_ConcatManyDim1bfloat16)->Arg(18)->Arg(34)->Arg(60);
+
+static void MemcpyAlternativeHelper(int iters, int concat_dimension, int dim2) {
+ testing::StopTiming();
+
+ const int kDim1 = 100;
+ std::vector<float> data1(kDim1 * dim2, 1.0f);
+ std::vector<float> data2(kDim1 * dim2, 2.0f);
+
+ testing::BytesProcessed(static_cast<int64>(iters) *
+ ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(float));
+ testing::StartTiming();
+ while (--iters > 0) {
+ const int n0 = data1.size();
+ const int n1 = data2.size();
+ float* result = new float[n0 + n1];
+ memcpy(&result[0], &data1[0], n0 * sizeof(float));
+ memcpy(&result[n0], &data2[0], n1 * sizeof(float));
+ delete[] result;
+ }
+}
+
+static void BM_MemcpyAlternativeDim0(int iters, int dim2) {
+ MemcpyAlternativeHelper(iters, 0, dim2);
+}
+static void BM_MemcpyAlternativeDim1(int iters, int dim2) {
+ MemcpyAlternativeHelper(iters, 1, dim2);
+}
+
+BENCHMARK(BM_MemcpyAlternativeDim0)->Arg(1000)->Arg(100000)->Arg(1000000);
+BENCHMARK(BM_MemcpyAlternativeDim1)->Arg(1000)->Arg(100000)->Arg(1000000);
+
+typedef Eigen::TensorMap<Eigen::Tensor<bfloat16, 1, Eigen::RowMajor>,
+ Eigen::Unaligned> EigenMap;
+static void MemcpyManyAlternative1(int iters, int dim2) {
+ testing::StopTiming();
+
+ const int kDim1 = 40000;
+ const int kNumCopies = 64;
+ const int size = kDim1 * dim2 * kNumCopies;
+ bfloat16* data = new bfloat16[size];
+ EigenMap map(data, size);
+ map.setRandom();
+
+ testing::BytesProcessed(static_cast<int64>(iters) * kDim1 * dim2 *
+ kNumCopies * sizeof(bfloat16));
+ testing::StartTiming();
+ while (iters-- > 0) {
+ std::vector<bfloat16*> inputs(kNumCopies);
+ for (int i = 0; i < kNumCopies; ++i) {
+ inputs[i] = &data[i * kDim1 * dim2];
+ }
+ bfloat16* result = new bfloat16[size];
+ for (int j = 0; j < kNumCopies; ++j) {
+ bfloat16* output = &result[j * dim2];
+ for (int i = 0; i < kDim1; ++i) {
+ if (i + 1 < kDim1) {
+ port::prefetch<port::PREFETCH_HINT_T0>(inputs[j] + dim2);
+ }
+ memcpy(output, inputs[j], dim2 * sizeof(bfloat16));
+ inputs[j] += dim2;
+ output += dim2 * kNumCopies;
+ }
+ }
+ delete[] result;
+ }
+ delete[] data;
+}
+
+static void MemcpyManyAlternative2(int iters, int dim2) {
+ testing::StopTiming();
+
+ const int kDim1 = 40000;
+ const int kNumCopies = 64;
+ const int size = kDim1 * dim2 * kNumCopies;
+ bfloat16* data = new bfloat16[size];
+ EigenMap map(data, size);
+ map.setRandom();
+
+ testing::BytesProcessed(static_cast<int64>(iters) * kDim1 * dim2 *
+ kNumCopies * sizeof(bfloat16));
+ testing::StartTiming();
+ std::vector<bfloat16*> inputs(kNumCopies);
+ while (--iters > 0) {
+ bfloat16* result = new bfloat16[size];
+ for (int i = 0; i < kNumCopies; ++i) {
+ inputs[i] = &data[i * kDim1 * dim2];
+ }
+ bfloat16* output = result;
+ for (int i = 0; i < kDim1; ++i) {
+ for (int j = 0; j < kNumCopies; ++j) {
+ if (j + 1 < kNumCopies) {
+ port::prefetch<port::PREFETCH_HINT_T0>(inputs[j + 1]);
+ }
+ memcpy(output, inputs[j], dim2 * sizeof(bfloat16));
+ inputs[j] += dim2;
+ output += dim2;
+ }
+ }
+ delete[] result;
+ }
+ delete[] data;
+}
+
+BENCHMARK(MemcpyManyAlternative1)
+ ->Arg(16)
+ ->Arg(17)
+ ->Arg(18)
+ ->Arg(32)
+ ->Arg(33)
+ ->Arg(34)
+ ->Arg(60)
+ ->Arg(64)
+ ->Arg(65);
+
+BENCHMARK(MemcpyManyAlternative2)
+ ->Arg(16)
+ ->Arg(17)
+ ->Arg(18)
+ ->Arg(32)
+ ->Arg(33)
+ ->Arg(34)
+ ->Arg(60)
+ ->Arg(64)
+ ->Arg(65);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
new file mode 100644
index 0000000000..281bafd3df
--- /dev/null
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -0,0 +1,249 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/constant_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+ConstantOp::ConstantOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+ const TensorProto* proto = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+ OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(
+ *proto, AllocatorAttributes(), &tensor_));
+ OP_REQUIRES(
+ ctx, ctx->output_type(0) == tensor_.dtype(),
+ errors::InvalidArgument("Type mismatch between value (",
+ DataTypeString(tensor_.dtype()), ") and dtype (",
+ DataTypeString(ctx->output_type(0)), ")"));
+}
+
+void ConstantOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, tensor_); }
+
+ConstantOp::~ConstantOp() {}
+
+REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp);
+
+#if GOOGLE_CUDA
+#define REGISTER_KERNEL(D, TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Const").Device(DEVICE_##D).TypeConstraint<TYPE>("dtype"), \
+ ConstantOp);
+REGISTER_KERNEL(GPU, float);
+REGISTER_KERNEL(GPU, double);
+REGISTER_KERNEL(GPU, uint8);
+REGISTER_KERNEL(GPU, int8);
+REGISTER_KERNEL(GPU, int16);
+REGISTER_KERNEL(GPU, int64);
+REGISTER_KERNEL(GPU, complex64);
+REGISTER_KERNEL(GPU, bool);
+// Currently we do not support string constants on GPU
+#undef REGISTER_KERNEL
+#endif
+
+// HostConstantOp differs from ConstantOp in that its output is always
+// in host memory.
+class HostConstantOp : public OpKernel {
+ public:
+ explicit HostConstantOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+ const TensorProto* proto = nullptr;
+ AllocatorAttributes alloc_attr;
+ alloc_attr.set_on_host(true);
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+ OP_REQUIRES_OK(
+ ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_));
+ OP_REQUIRES(
+ ctx, ctx->output_type(0) == tensor_.dtype(),
+ errors::InvalidArgument(
+ "Type mismatch between value (", DataTypeString(tensor_.dtype()),
+ ") and dtype (", DataTypeString(ctx->output_type(0)), ")"));
+ }
+
+ void Compute(OpKernelContext* ctx) override { ctx->set_output(0, tensor_); }
+
+ bool IsExpensive() override { return false; }
+
+ ~HostConstantOp() override {}
+
+ private:
+ Tensor tensor_;
+ TF_DISALLOW_COPY_AND_ASSIGN(HostConstantOp);
+};
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Const")
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<int32>("dtype"),
+ HostConstantOp);
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// Partial specialization of FillFunctor<Device=CPUDevice, T>.
+template <typename T>
+struct FillFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstScalar in) {
+ out.device(d) = out.constant(in());
+ }
+};
+
+// Partial specialization of SetZeroFunctor<Device=CPUDevice, T>.
+template <typename T>
+struct SetZeroFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out) {
+ out.device(d) = out.constant(0);
+ }
+};
+
+#define DEFINE_SETZERO_CPU(T) template struct SetZeroFunctor<CPUDevice, T>
+DEFINE_SETZERO_CPU(float);
+DEFINE_SETZERO_CPU(double);
+DEFINE_SETZERO_CPU(int32);
+DEFINE_SETZERO_CPU(complex64);
+#undef DEFINE_SETZERO_CPU
+
+} // end namespace functor
+
+template <typename Device, typename T>
+class FillOp : public OpKernel {
+ public:
+ explicit FillOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& Tdims = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()),
+ errors::InvalidArgument("dims must be a vector of int32."));
+ const Tensor& Tvalue = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()),
+ errors::InvalidArgument("value must be a scalar."));
+ auto dims = Tdims.flat<int32>();
+ for (int i = 0; i < dims.size(); i++) {
+ OP_REQUIRES(context, dims(i) >= 0,
+ errors::InvalidArgument("dims[", i, "] = ", dims(i),
+ " must be nonnegative."));
+ }
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_output(
+ 0, TensorShapeUtils::MakeShape(
+ reinterpret_cast<const int32*>(dims.data()), dims.size()),
+ &out));
+ functor::FillFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), out->flat<T>(),
+ Tvalue.scalar<T>());
+ }
+};
+
+#define REGISTER_KERNEL(D, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Fill") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<TYPE>("T") \
+ .HostMemory("dims"), \
+ FillOp<D##Device, TYPE>);
+
+#define REGISTER_CPU_KERNEL(TYPE) REGISTER_KERNEL(CPU, TYPE)
+TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL(GPU, float);
+REGISTER_KERNEL(GPU, double);
+REGISTER_KERNEL(GPU, uint8);
+REGISTER_KERNEL(GPU, int8);
+REGISTER_KERNEL(GPU, int16);
+REGISTER_KERNEL(GPU, int64);
+// Currently we do not support filling strings and complex64 on GPU
+
+#endif // GOOGLE_CUDA
+
+#undef REGISTER_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Fill")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .HostMemory("dims")
+ .HostMemory("value")
+ .HostMemory("output"),
+ FillOp<CPUDevice, int32>);
+
+template <typename Device, typename T>
+class ZerosLikeOp : public OpKernel {
+ public:
+ explicit ZerosLikeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& input = ctx->input(0);
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &out));
+ Tensor zero(DataTypeToEnum<T>::value, {1});
+ zero.scalar<T>().setZero();
+ const Tensor& zero_cref = zero;
+ functor::FillFunctor<Device, T> functor;
+ functor(ctx->eigen_device<Device>(), out->flat<T>(), zero_cref.scalar<T>());
+ }
+};
+
+#define REGISTER_KERNEL(type, dev) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ZerosLike").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
+ ZerosLikeOp<dev##Device, type>)
+
+#define REGISTER_CPU(type) REGISTER_KERNEL(type, CPU)
+TF_CALL_ALL_TYPES(REGISTER_CPU);
+#undef REGISTER_CPU
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL(float, GPU);
+REGISTER_KERNEL(double, GPU);
+#endif // GOOGLE_CUDA
+
+#undef REGISTER_KERNEL
+
+class PlaceholderOp : public OpKernel {
+ public:
+ explicit PlaceholderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (expected_shape_.dims() > 0) {
+ OP_REQUIRES(ctx, false,
+ errors::InvalidArgument(
+ "You must feed a value for placeholder tensor '", name(),
+ "' with dtype ", DataTypeString(output_type(0)),
+ " and shape ", expected_shape_.DebugString()));
+ } else {
+ OP_REQUIRES(ctx, false,
+ errors::InvalidArgument(
+ "You must feed a value for placeholder tensor '", name(),
+ "' with dtype ", DataTypeString(output_type(0))));
+ }
+ }
+
+ private:
+ TensorShape expected_shape_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE_CPU), PlaceholderOp);
+
+} // namespace tensorflow
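
FillFunctor<CPUDevice, T> above evaluates out = out.constant(value), i.e. it broadcasts a scalar across the flat output, and ZerosLikeOp reuses it with a scalar zero. A minimal host-side Eigen sketch of that expression (assuming the unsupported Tensor headers are on the include path):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 1> out(8);
  out = out.constant(3.5f);   // what FillFunctor evaluates, minus the .device(d)
  std::cout << out << "\n";   // eight copies of 3.5
}
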
diff --git a/tensorflow/core/kernels/constant_op.h b/tensorflow/core/kernels/constant_op.h
new file mode 100644
index 0000000000..20a5c9c42f
--- /dev/null
+++ b/tensorflow/core/kernels/constant_op.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_KERNELS_CONSTANT_OP_H_
+#define TENSORFLOW_KERNELS_CONSTANT_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// ConstantOp returns the tensor specified by its "value" attr.
+class ConstantOp : public OpKernel {
+ public:
+ explicit ConstantOp(OpKernelConstruction* ctx);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+ ~ConstantOp() override;
+
+ private:
+ Tensor tensor_;
+ TF_DISALLOW_COPY_AND_ASSIGN(ConstantOp);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CONSTANT_OP_H_
diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc
new file mode 100644
index 0000000000..64502378bd
--- /dev/null
+++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc
@@ -0,0 +1,89 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace Eigen {
+namespace internal {
+
+template <typename T>
+struct scalar_const_op {
+ typedef typename packet_traits<T>::type Packet;
+
+ const T* val;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ scalar_const_op(const scalar_const_op& x)
+ : val(x.val) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_const_op(const T* v) : val(v) {}
+
+ template <typename Index>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(Index,
+ Index = 0) const {
+ return *val;
+ }
+
+ template <typename Index>
+ EIGEN_STRONG_INLINE const Packet packetOp(Index, Index = 0) const {
+ return internal::pset1<Packet>(*val);
+ }
+};
+
+template <typename T>
+struct functor_traits<scalar_const_op<T> > {
+ enum {
+ Cost = 1,
+ PacketAccess = packet_traits<T>::Vectorizable,
+ IsRepeatable = true
+ };
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+namespace tensorflow {
+
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization FillFunctor<Device=GPUDevice, T>
+template <typename T>
+struct FillFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstScalar in) {
+ Eigen::internal::scalar_const_op<T> f(in.data());
+ out.device(d) = out.nullaryExpr(f);
+ }
+};
+
+#define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>
+DEFINE_FILL_GPU(float);
+DEFINE_FILL_GPU(double);
+DEFINE_FILL_GPU(int32);
+DEFINE_FILL_GPU(uint8);
+DEFINE_FILL_GPU(int16);
+DEFINE_FILL_GPU(int8);
+DEFINE_FILL_GPU(int64);
+#undef DEFINE_FILL_GPU
+
+// Partial specialization of SetZeroFunctor<Device=GPUDevice, T>.
+template <typename T>
+struct SetZeroFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out) {
+ out.device(d) = out.constant(0);
+ }
+};
+
+#define DEFINE_SETZERO_GPU(T) template struct SetZeroFunctor<GPUDevice, T>
+DEFINE_SETZERO_GPU(float);
+#undef DEFINE_SETZERO_GPU
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/constant_op_test.cc b/tensorflow/core/kernels/constant_op_test.cc
new file mode 100644
index 0000000000..f5a464c07c
--- /dev/null
+++ b/tensorflow/core/kernels/constant_op_test.cc
@@ -0,0 +1,43 @@
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Returns a graph containing "num" const nodes. If 'sequential' is
+// true, all constants are forced to execute sequentially in the
+// graph by adding control dependencies.
+static Graph* ManyConsts(int num, bool sequential) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Node* prev = nullptr;
+ for (int i = 0; i < num; ++i) {
+ Tensor c(DT_FLOAT, TensorShape({}));
+ c.scalar<float>()() = i;
+ Node* curr = test::graph::Constant(g, c);
+ if (sequential && prev != nullptr) {
+ g->AddControlEdge(prev, curr);
+ }
+ prev = curr;
+ }
+ return g;
+}
+
+static void BM_ManyConsts_Parallel(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ test::Benchmark("cpu", ManyConsts(num, false /* !sequential */)).Run(iters);
+}
+BENCHMARK(BM_ManyConsts_Parallel)->Range(1, 1 << 10);
+
+static void BM_ManyConsts_Sequential(int iters, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ test::Benchmark("cpu", ManyConsts(num, true /* sequential */)).Run(iters);
+}
+BENCHMARK(BM_ManyConsts_Sequential)->Range(1, 1 << 10);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
new file mode 100644
index 0000000000..bc44a7f7cc
--- /dev/null
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -0,0 +1,359 @@
+#include "tensorflow/core/kernels/control_flow_ops.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// A switch op has two inputs and two outputs. It forwards the value of
+// input:0 to the output port selected by input:1, which must be a boolean
+// scalar. Input:0 is forwarded to output:0 if input:1 is false, and to
+// output:1 otherwise.
+class SwitchOp : public OpKernel {
+ public:
+ explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& outputPorts = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(outputPorts.shape()),
+ errors::InvalidArgument("The second input must be a scalar, "
+ "but it has shape ",
+ outputPorts.shape().ShortDebugString()));
+
+ bool pred = outputPorts.scalar<bool>()();
+ int port = (pred) ? 1 : 0;
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, port);
+ } else {
+ context->set_output(port, context->input(0));
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~SwitchOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SwitchOp);
+};
+
+#define REGISTER_CPU_SWITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("Switch") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+#define REGISTER_CPU_REF_SWITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+#define REGISTER_GPU_SWITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("Switch") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+#define REGISTER_GPU_REF_SWITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSwitch") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("pred") \
+ .TypeConstraint<type>("T"), \
+ SwitchOp)
+
+TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
+TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SWITCH);
+REGISTER_GPU_SWITCH(bool);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_REF_SWITCH);
+REGISTER_GPU_REF_SWITCH(int32);
+REGISTER_GPU_REF_SWITCH(bool);
+
+#undef REGISTER_CPU_SWITCH
+#undef REGISTER_CPU_REF_SWITCH
+#undef REGISTER_GPU_SWITCH
+#undef REGISTER_GPU_REF_SWITCH
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Switch")
+ .Device(DEVICE_GPU)
+ .HostMemory("data")
+ .HostMemory("pred")
+ .HostMemory("output_false")
+ .HostMemory("output_true")
+ .TypeConstraint<int32>("T"),
+ SwitchOp);
+
+class RefSelectOp : public OpKernel {
+ public:
+ explicit RefSelectOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("N", &num_ref_inputs_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& index_tensor = context->input(0);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(index_tensor.shape()),
+ errors::InvalidArgument("Index must be a scalar, "
+ "but it has shape ",
+ index_tensor.shape().ShortDebugString()));
+
+ int32 index = index_tensor.scalar<int32>()();
+
+ OP_REQUIRES(context, index >= 0 && index < num_ref_inputs_,
+ errors::InvalidArgument("Index must be in the range [0, ",
+ num_ref_inputs_, ") but got ", index));
+ context->forward_ref_input_to_ref_output(index + 1, 0);
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~RefSelectOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RefSelectOp);
+
+ private:
+ int num_ref_inputs_;
+};
+
+#define REGISTER_CPU_REF_SELECT(type) \
+ REGISTER_KERNEL_BUILDER(Name("RefSelect") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("index") \
+ .TypeConstraint<type>("T"), \
+ RefSelectOp)
+TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SELECT);
+
+#undef REGISTER_CPU_REF_SWITCH
+
+// A merge op has n inputs and two outputs. It forwards the value of the
+// first input that becomes available to its first output, and the index
+// of that input to its second output.
+class MergeOp : public OpKernel {
+ public:
+ explicit MergeOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = context->input_type(0);
+ const int num_in = context->num_inputs();
+ OP_REQUIRES_OK(context, context->MatchSignature(DataTypeVector(num_in, dt),
+ {dt, DT_INT32}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ bool input_seen = false;
+ for (int i = 0; i < context->num_inputs(); ++i) {
+ if (context->has_input(i)) {
+ if (input_seen) {
+ context->SetStatus(errors::Internal(
+ "Merge can not have more than one valid input."));
+ return;
+ }
+ input_seen = true;
+
+ context->set_output(0, context->input(i));
+ Tensor* value_index = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}),
+ &value_index));
+ value_index->scalar<int32>()() = i;
+ }
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~MergeOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MergeOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Merge") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("value_index"), \
+ MergeOp);
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Merge")
+ .Device(DEVICE_GPU)
+ .HostMemory("inputs")
+ .HostMemory("output")
+ .HostMemory("value_index")
+ .TypeConstraint<int32>("T"),
+ MergeOp);
+
+// An enter op has one input and one output. It creates or finds
+// the child frame that is uniquely identified by the frame_name,
+// and makes its input available to the child frame.
+class EnterOp : public OpKernel {
+ public:
+ explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~EnterOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EnterOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_CPU), EnterOp);
+REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Enter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp);
+#define REGISTER_GPU_REF_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefEnter").Device(DEVICE_GPU).TypeConstraint<type>("T"), EnterOp);
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+TF_CALL_NUMBER_TYPES(REGISTER_GPU_REF_KERNEL);
+
+#undef REGISTER_GPU_KERNEL
+#undef REGISTER_GPU_REF_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Enter")
+ .Device(DEVICE_GPU)
+ .HostMemory("data")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ EnterOp);
+
+// An exit op has one input and one output. It exits the current
+// frame to its parent frame, and makes its input available to the
+// parent frame.
+class ExitOp : public OpKernel {
+ public:
+ explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ context->set_output(0, context->input(0));
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~ExitOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExitOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_CPU), ExitOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Exit").Device(DEVICE_GPU).TypeConstraint<type>("T"), ExitOp);
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Exit")
+ .Device(DEVICE_GPU)
+ .HostMemory("data")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ ExitOp);
+
+// A next_iteration op has one input and one output. It makes its input
+// available to the next iteration.
+class NextIterationOp : public OpKernel {
+ public:
+ explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ context->set_output(0, context->input(0));
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~NextIterationOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_CPU),
+ NextIterationOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("NextIteration").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ NextIterationOp);
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("NextIteration")
+ .Device(DEVICE_GPU)
+ .HostMemory("data")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ NextIterationOp);
+
+// A LoopCond op has one input and one output. The input is a boolean
+// scalar representing the taken branches of the "pivot" Switch that
+// determines loop termination. As a contract, any high-level front-end
+// should always use port '0' of the "pivot" switches for loop exit.
+class LoopCondOp : public OpKernel {
+ public:
+ explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ context->set_output(0, context->input(0));
+ }
+
+ bool IsExpensive() override { return false; }
+
+ ~LoopCondOp() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
+REGISTER_KERNEL_BUILDER(Name("LoopCond")
+ .Device(DEVICE_GPU)
+ .HostMemory("input")
+ .HostMemory("output"),
+ LoopCondOp);
+
+// ControlTrigger kernels
+REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_CPU),
+ ControlTriggerOp);
+
+REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE_GPU),
+ ControlTriggerOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h
new file mode 100644
index 0000000000..184cc9fb63
--- /dev/null
+++ b/tensorflow/core/kernels/control_flow_ops.h
@@ -0,0 +1,22 @@
+#ifndef TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
+#define TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+// A ControlTriggerOp is similar to a NoOp. However, it always treats the input
+// control edges as Live edges. Its primary use so far is in the scheduling of
+// recvs, where we add ControlTrigger nodes and use them to trigger recvs. We
+// allow ControlTrigger nodes to be enabled by dead nodes.
+class ControlTriggerOp : public OpKernel {
+ public:
+ explicit ControlTriggerOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+ bool IsExpensive() override { return false; }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
diff --git a/tensorflow/core/kernels/control_flow_ops_test.cc b/tensorflow/core/kernels/control_flow_ops_test.cc
new file mode 100644
index 0000000000..52bc11abf0
--- /dev/null
+++ b/tensorflow/core/kernels/control_flow_ops_test.cc
@@ -0,0 +1,71 @@
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+// Tests for the switch op
+class SwitchOpTest : public OpsTestBase {
+ protected:
+ void Initialize(DataType dt) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("op", "Switch")
+ .Input(FakeInput(dt))
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
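+// The test names encode the input shape and the expected output port, e.g.
+// "Int32Success_6_s0" feeds a length-6 int32 vector with pred=false, so the
+// data should be routed to output 0 (output_false).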
+TEST_F(SwitchOpTest, Int32Success_6_s0) {
+ Initialize(DT_INT32);
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<bool>(TensorShape({}), {false});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+ EXPECT_EQ(nullptr, GetOutput(1));
+}
+
+TEST_F(SwitchOpTest, Int32Success_6_s1) {
+ Initialize(DT_INT32);
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<bool>(TensorShape({}), {true});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(1));
+ EXPECT_EQ(nullptr, GetOutput(0));
+}
+
+TEST_F(SwitchOpTest, Int32Success_2_3_s0) {
+ Initialize(DT_INT32);
+ AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<bool>(TensorShape({}), {false});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({2, 3}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+ EXPECT_EQ(nullptr, GetOutput(1));
+}
+
+TEST_F(SwitchOpTest, StringSuccess_s1) {
+ Initialize(DT_STRING);
+ AddInputFromArray<string>(TensorShape({6}), {"A", "b", "C", "d", "E", "f"});
+ AddInputFromArray<bool>(TensorShape({}), {true});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({6}));
+ test::FillValues<string>(&expected, {"A", "b", "C", "d", "E", "f"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(1));
+ EXPECT_EQ(nullptr, GetOutput(0));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
new file mode 100644
index 0000000000..2fb623244c
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -0,0 +1,127 @@
+#ifndef TENSORFLOW_KERNELS_CONV_2D_H_
+#define TENSORFLOW_KERNELS_CONV_2D_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// TODO(yangke): revisit these operations and in particular, see if we can
+// combine all of them into just one operation without causing nvcc to
+// timeout.
+template <typename Device, typename T, int Dims>
+struct ShuffleAndReverse {
+ void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
+ const Eigen::DSizes<Eigen::DenseIndex, Dims>& order,
+ const Eigen::array<bool, Dims>& reverse_dims,
+ typename TTypes<T, Dims>::Tensor output) {
+ output.device(d) = input.shuffle(order).reverse(reverse_dims);
+ }
+};
+
+template <typename Device, typename T, int Dims>
+struct InflatePadAndShuffle {
+ void operator()(
+ const Device& d, typename TTypes<T, Dims>::ConstTensor input,
+ const Eigen::DSizes<Eigen::DenseIndex, Dims>& strides,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, Dims>& pad_dims,
+ const Eigen::DSizes<Eigen::DenseIndex, Dims>& order,
+ typename TTypes<T, Dims>::Tensor output) {
+ output.device(d) = input.inflate(strides).pad(pad_dims).shuffle(order);
+ }
+};
+
+template <typename Device, typename Input, typename Filter, typename Output>
+void SpatialConvolutionFunc(const Device& d, Output output, Input input,
+ Filter filter, int stride,
+ const Eigen::PaddingType& padding) {
+ output.device(d) = Eigen::SpatialConvolution(input, filter, stride, padding);
+}
+
+template <typename Device, typename T>
+struct SpatialConvolution {
+ void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
+ typename TTypes<T, 4>::ConstTensor input,
+ typename TTypes<T, 4>::ConstTensor filter, int stride,
+ const Eigen::PaddingType& padding) {
+ SpatialConvolutionFunc(d, output, input, filter, stride, padding);
+ }
+};
+
+template <typename Device, typename T>
+struct SpatialConvolutionBackwardInput {
+ void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward,
+ typename TTypes<T, 4>::ConstTensor kernel,
+ typename TTypes<T, 4>::ConstTensor output_backward,
+ int input_rows, int input_cols, int stride) {
+ input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput(
+ kernel, output_backward, input_rows, input_cols, stride);
+ }
+};
+
+template <typename Device, typename T>
+struct SpatialConvolutionBackwardKernel {
+ void operator()(const Device& d,
+ typename TTypes<T, 4>::Tensor kernel_backward,
+ typename TTypes<T, 4>::ConstTensor input,
+ typename TTypes<T, 4>::ConstTensor output_backward,
+ int kernel_rows, int kernel_cols, int stride) {
+ kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel(
+ input, output_backward, kernel_rows, kernel_cols, stride);
+ }
+};
+
+// TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h.
+// My initial attempt to do this compiled but failed in the pytest
+// due to a swigdeps error.
+template <typename Device, typename T>
+struct MatMulConvFunctor {
+ // Computes on device "d": out = in0 * in1, where * is matrix
+ // multiplication.
+ void operator()(
+ const Device& d, typename TTypes<T, 2>::Tensor out,
+ typename TTypes<T, 2>::ConstTensor in0,
+ typename TTypes<T, 2>::ConstTensor in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ out.device(d) = in0.contract(in1, dim_pair);
+ }
+};
+
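+// TransformFilter's shuffle (3, 2, 0, 1) converts a filter laid out as
+// [rows, cols, in_depth, out_depth] into [out_depth, in_depth, rows, cols],
+// the layout cuDNN expects (see the NOTE in conv_grad_ops.cc).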
+template <typename Device, typename T>
+struct TransformFilter {
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
+ typename TTypes<T, 4>::Tensor out) {
+ out.device(d) = in.shuffle(Eigen::DSizes<Eigen::DenseIndex, 4>(3, 2, 0, 1));
+ }
+};
+
+template <typename Device, typename T>
+struct TransformDepth {
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
+ const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle,
+ typename TTypes<T, 4>::Tensor out) {
+ out.device(d) = in.shuffle(shuffle);
+ }
+};
+
+template <typename Device, typename T>
+struct PadInput {
+ void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
+ int padding_rows_left, int padding_rows_right,
+ int padding_cols_left, int padding_cols_right,
+ typename TTypes<T, 4>::Tensor out) {
+ Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_rows_left, padding_rows_right);
+ padding[2] = std::make_pair(padding_cols_left, padding_cols_right);
+ padding[3] = std::make_pair(0, 0);
+ out.device(d) = in.pad(padding);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CONV_2D_H_
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
new file mode 100644
index 0000000000..bb21d7003c
--- /dev/null
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -0,0 +1,1190 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// The operation to compute Conv2D gradients.
+//
+//
+// To compute the gradients for Conv2D, we need three input tensors:
+// input, filter, and backprop for output.
+// And we need to compute two backprops: one for input and one for filter. We
+// compute them in two different kernels.
+
+// Both backprops can be computed as straightforward conv2d.
+//
+// Consider a case where the input is 3x3 and the filter is 2x1:
+//
+// INPUT = [ A B C ]
+// [ D E F ]
+// [ G H I ]
+//
+// where each "A", "B", etc is batch x in_depth
+//
+// FILTER = [ X Y ]
+//
+// where both "X" and "Y" are in_depth x out_depth
+//
+// With VALID padding, the output is 3x2:
+//
+// OUTPUT = [ a b ]
+// [ c d ]
+// [ e f ]
+//
+// where each "a", "b", etc is batch x out_depth
+//
+// So we have:
+//
+// a = A * X + B * Y
+// b = B * X + C * Y
+// c = D * X + E * Y
+// d = E * X + F * Y
+// e = G * X + H * Y
+// f = H * X + I * Y
+//
+// So when we have backprops for the outputs (we denote them by
+// a', b', ... ):
+//
+// The backprops for the input are:
+//
+// A' = a' * X^t
+// B' = a' * Y^t + b' * X^t
+// C' = b' * Y^t
+// ...
+//
+// This is essentially computing a 2d conv of
+//
+// INPUT = [ 0 a' b' 0 ]
+// [ 0 c' d' 0 ]
+// [ 0 e' f' 0 ]
+// and
+//
+// FILTER = [ Y^t X^t ]
+//
+// The backprops for the filter are:
+//
+// X' = A^t * a' + B^t * b' + D^t * c' + E^t * d' + G^t * e' + H^t * f'
+// Y' = B^t * a' + C^t * b' + E^t * c' + F^t * d' + H^t * e' + I^t * f'
+//
+// This is essentially computing a 2d conv of
+//
+// INPUT = [ A^t B^t C^t ]
+// [ D^t E^t F^t ]
+// [ G^t H^t I^t ]
+//
+// and
+//
+// FILTER = [ a' b' ]
+// [ c' d' ]
+// [ e' f' ]
+//
+//
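+// As a quick numeric check (illustrative values; batch, in_depth and
+// out_depth all 1): take INPUT = [1 2 3; 4 5 6; 7 8 9] and FILTER = [X Y] =
+// [1 10]. Then a = 1 + 20 = 21, b = 2 + 30 = 32, and so on. If every output
+// backprop a', ..., f' is 1, the first row of the input backprop is A' = 1,
+// B' = 10 + 1 = 11, C' = 10, which is exactly the convolution of the padded
+// row [0 1 1 0] with the reversed filter [Y X] = [10 1]. The filter backprops
+// are X' = 1+2+4+5+7+8 = 27 and Y' = 2+3+5+6+8+9 = 33.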
+//////////////////////////////////////////////////////////
+//
+// With a stride of more than one, it's a bit more complicated (we will need
+// to insert holes into the backprop).
+//
+// Consider the case where
+//
+// INPUT = [ A B C D E ]
+// [ F G H I J ]
+// [ K L M N O ]
+// and
+//
+// FILTER = [ X Y Z ]
+//
+// with stride 2.
+//
+// The output will be
+//
+// OUTPUT = [ a b ]
+// [ c d ]
+//
+// where:
+//
+// a = A * X + B * Y + C * Z
+// b = C * X + D * Y + E * Z
+// c = K * X + L * Y + M * Z
+// d = M * X + N * Y + O * Z
+//
+//
+// To compute the backprop for INPUT, we need to convolve
+//
+// INPUT = [ 0 0 a' 0 b' 0 0 ]
+// [ 0 0 0 0 0 0 0 ]
+// [ 0 0 c' 0 d' 0 0 ]
+//
+// (notice the holes in INPUT)
+//
+// and
+//
+// FILTER = [ Z^t Y^t X^t ]
+//
+// with stride 1.
+//
+// To compute the backprop for FILTER, we need to convolve
+//
+//
+// INPUT = [ A^t B^t C^t D^t E^t ]
+// [ F^t G^t H^t I^t J^t ]
+// [ K^t L^t M^t N^t O^t ]
+// and
+//
+// FILTER = [ a' 0 b' ]
+// [ 0 0 0 ]
+// [ c' 0 d' ]
+//
+// (notice the holes in FILTER)
+//
+//
+// with stride 1
+//
+//////////////////////////////////////////////////////////
+//
+//
+// The case for SAME padding is in fact very similar to VALID -- we just
+// need to pad the input tensor a bit when computing the filter_backprop.
+
+// Common code between the two kernels: verifies that the dimensions all match
+// and extracts the padded rows and columns.
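+// The macro introduces, in the enclosing scope: out_backprop, batch,
+// input_rows/cols, filter_rows/cols, output_rows/cols, in_depth, out_depth,
+// stride, the computed out_rows/cols and pad_rows/cols, the expanded and
+// padded output sizes, the four per-side padding amounts, and a 4-D Eigen
+// "strides" object used by InflatePadAndShuffle.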
+#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
+ const Tensor& out_backprop = context->input(2); \
+ OP_REQUIRES( \
+ context, input_shape.dims() == 4, \
+ errors::InvalidArgument(label, ": input must be 4-dimensional")); \
+ OP_REQUIRES( \
+ context, filter_shape.dims() == 4, \
+ errors::InvalidArgument(label, ": filter must be 4-dimensional")); \
+ OP_REQUIRES( \
+ context, out_backprop.dims() == 4, \
+ errors::InvalidArgument(label, ": out_backprop must be 4-dimensional")); \
+ const int64 batch = input_shape.dim_size(0); \
+ OP_REQUIRES( \
+ context, batch == out_backprop.dim_size(0), \
+ errors::InvalidArgument( \
+ label, ": input and out_backprop must have the same batch size")); \
+ const int64 input_rows = input_shape.dim_size(1); \
+ const int64 input_cols = input_shape.dim_size(2); \
+ const int64 filter_rows = filter_shape.dim_size(0); \
+ const int64 filter_cols = filter_shape.dim_size(1); \
+ const int64 output_rows = out_backprop.dim_size(1); \
+ const int64 output_cols = out_backprop.dim_size(2); \
+ const int64 in_depth = input_shape.dim_size(3); \
+ OP_REQUIRES(context, in_depth == filter_shape.dim_size(2), \
+ errors::InvalidArgument( \
+ label, ": input and filter must have the same depth")); \
+ const int64 out_depth = filter_shape.dim_size(3); \
+ OP_REQUIRES( \
+ context, out_depth == out_backprop.dim_size(3), \
+ errors::InvalidArgument( \
+ label, ": filter and out_backprop must have the same out_depth")); \
+ const auto stride = strides_[1]; \
+ int out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; \
+ if (filter_cols == filter_rows && filter_rows == 1 && stride == 1) { \
+ out_rows = input_rows; \
+ out_cols = input_cols; \
+ } else { \
+ OP_REQUIRES_OK( \
+ context, Get2dOutputSize(input_rows, input_cols, filter_rows, \
+ filter_cols, stride, stride, padding_, \
+ &out_rows, &out_cols, &pad_rows, &pad_cols)); \
+ } \
+ OP_REQUIRES( \
+ context, output_rows == out_rows, \
+ errors::InvalidArgument( \
+ label, ": Number of rows of out_backprop doesn't match computed: ", \
+ "actual = ", output_rows, ", computed = ", out_rows)); \
+ OP_REQUIRES( \
+ context, output_cols == out_cols, \
+ errors::InvalidArgument( \
+ label, ": Number of cols of out_backprop doesn't match computed: ", \
+ "actual = ", output_cols, ", computed = ", out_cols)); \
+ const auto expanded_out_rows = (output_rows - 1) * stride + 1; \
+ const auto expanded_out_cols = (output_cols - 1) * stride + 1; \
+ const auto padded_out_rows = input_rows + filter_rows - 1; \
+ const auto padded_out_cols = input_cols + filter_cols - 1; \
+ const auto top_pad_rows = filter_rows - 1 - pad_rows; \
+ const auto left_pad_cols = filter_cols - 1 - pad_cols; \
+ const auto bottom_pad_rows = \
+ padded_out_rows - expanded_out_rows - top_pad_rows; \
+ const auto right_pad_cols = \
+ padded_out_cols - expanded_out_cols - left_pad_cols; \
+ Eigen::DSizes<Eigen::DenseIndex, 4> strides{1, stride, stride, 1}; \
+ VLOG(2) << "Conv2d: " << label \
+ << ": expanded_out_rows = " << expanded_out_rows \
+ << ", expanded_out_cols = " << expanded_out_cols \
+ << ", filter_rows = " << filter_rows \
+ << ", filter_cols = " << filter_cols \
+ << ", padded_out_rows = " << padded_out_rows \
+ << ", padded_out_cols = " << padded_out_cols \
+ << ", top_pad_rows = " << top_pad_rows \
+ << ", left_pad_cols = " << left_pad_cols \
+ << ", bottom_pad_rows = " << bottom_pad_rows \
+ << ", right_pad_cols = " << right_pad_cols \
+ << ", strides = " << strides[1]
+
+namespace {
+TensorShape VectorToShape(const TTypes<int32>::ConstVec& sizes) {
+ TensorShape shape;
+
+ using Index = TTypes<int32>::ConstVec::Index;
+ const Index dims = sizes.size();
+ for (Index i = 0; i < dims; ++i) {
+ shape.AddDim(sizes(i));
+ }
+
+ return shape;
+}
+} // namespace
+
+// The fast versions using eigen computations directly. They are only enabled
+// for CPU for now since nvcc times out when trying to compile them.
+// TODO(yangke): enable them for GPUs when we have a faster compiler.
+
+template <typename Device, class T>
+class Conv2DFastBackpropInputOp : public OpKernel {
+ public:
+ explicit Conv2DFastBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_sizes = context->input(0);
+ const Tensor& filter = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input_sizes.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+ input_sizes.dims()));
+ TensorShape input_shape = VectorToShape(input_sizes.vec<int32>());
+ const TensorShape& filter_shape = filter.shape();
+
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput");
+ Tensor* in_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+ // Need to flip the input_rows and input_cols when passing to eigen.
+ functor::SpatialConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(), in_backprop->tensor<T, 4>(),
+ filter.tensor<T, 4>(), out_backprop.tensor<T, 4>(), input_cols,
+ input_rows, stride);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropInputOp);
+};
+
+// Based on implementation written by Yangqing Jia (jiayq).
+template <typename Device, class T>
+class Conv2DCustomBackpropInputOp : public OpKernel {
+ public:
+ explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(
+ context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_sizes = context->input(0);
+ const Tensor& filter = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input_sizes.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+ input_sizes.dims()));
+ TensorShape input_shape = VectorToShape(input_sizes.vec<int32>());
+ const TensorShape& filter_shape = filter.shape();
+
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput");
+ Tensor* in_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ // TODO(andydavis) Consider moving code shared with
+ // Conv2DCustomBackpropFilterOp into a shared helper function.
+ int pad_top;
+ int pad_bottom;
+ int pad_left;
+ int pad_right;
+ OP_REQUIRES_OK(
+ context,
+ Get2dOutputSizeVerbose(input_rows, input_cols, filter_rows, filter_cols,
+ stride, stride, padding_, &out_rows, &out_cols,
+ &pad_top, &pad_bottom, &pad_left, &pad_right));
+
+ // The total dimension size of each kernel.
+ const int filter_total_size = filter_rows * filter_cols * in_depth;
+ // The output image size is the spatial size of the output.
+ const int output_image_size = out_rows * out_cols;
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({output_image_size, filter_total_size}), &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int input_offset = input_rows * input_cols * in_depth;
+ // The output offset corresponding to a single output image.
+ const int output_offset = out_rows * out_cols * out_depth;
+
+ auto* filter_data = filter.template flat<T>().data();
+ auto* col_buffer_data = col_buffer.template flat<T>().data();
+ auto* out_backprop_data = out_backprop.template flat<T>().data();
+ auto* input_backprop_data = in_backprop->template flat<T>().data();
+
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>> MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>> ConstMatrixMap;
+
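+    // Per image: A is the out_backprop slice [output_image_size, out_depth]
+    // and B is the filter [filter_total_size, out_depth], so C = A * B^T
+    // holds the gradient for every im2col patch; Col2im then scatters those
+    // patch gradients back into the input backprop.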
+ for (int image_id = 0; image_id < batch; ++image_id) {
+ // Compute gradient into col_buffer.
+ MatrixMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, out_depth);
+
+ // TODO(andydavis) Use a multi-threaded matmul implementation here.
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols, filter_rows,
+ filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
+ stride, input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ Conv2DCustomBackpropInputOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
+ .Device(DEVICE_CPU)
+ .Label("custom")
+ .TypeConstraint<float>("T"),
+ Conv2DCustomBackpropInputOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
+ .Device(DEVICE_CPU)
+ .Label("eigen_tensor")
+ .TypeConstraint<float>("T"),
+ Conv2DFastBackpropInputOp<CPUDevice, float>);
+
+template <typename Device, class T>
+class Conv2DFastBackpropFilterOp : public OpKernel {
+ public:
+ explicit Conv2DFastBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(filter_sizes.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
+ filter_sizes.dims()));
+ const TensorShape& input_shape = input.shape();
+ TensorShape filter_shape = VectorToShape(filter_sizes.vec<int32>());
+
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter");
+ Tensor* filter_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ // Need to flip the filter_rows and filter_cols when passing to eigen.
+ functor::SpatialConvolutionBackwardKernel<Device, T>()(
+ context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(),
+ input.tensor<T, 4>(), out_backprop.tensor<T, 4>(), filter_cols,
+ filter_rows, stride);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropFilterOp);
+};
+
+// Based on implementation written by Yangqing Jia (jiayq).
+template <typename Device, class T>
+class Conv2DCustomBackpropFilterOp : public OpKernel {
+ public:
+ explicit Conv2DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(
+ context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(filter_sizes.shape()),
+ errors::InvalidArgument(
+ "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, "
+ "not ",
+ filter_sizes.dims()));
+ const TensorShape& input_shape = input.shape();
+ TensorShape filter_shape = VectorToShape(filter_sizes.vec<int32>());
+
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DCustomBackpropFilter");
+ Tensor* filter_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ int pad_top;
+ int pad_bottom;
+ int pad_left;
+ int pad_right;
+ OP_REQUIRES_OK(
+ context,
+ Get2dOutputSizeVerbose(input_rows, input_cols, filter_rows, filter_cols,
+ stride, stride, padding_, &out_rows, &out_cols,
+ &pad_top, &pad_bottom, &pad_left, &pad_right));
+
+ // The total dimension size of each kernel.
+ const int filter_total_size = filter_rows * filter_cols * in_depth;
+ // The output image size is the spatial size of the output.
+ const int output_image_size = out_rows * out_cols;
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({output_image_size, filter_total_size}), &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int input_offset = input_rows * input_cols * in_depth;
+ // The output offset corresponding to a single output image.
+ const int output_offset = out_rows * out_cols * out_depth;
+
+ auto* input_data = input.template flat<T>().data();
+ auto* col_buffer_data = col_buffer.template flat<T>().data();
+ auto* out_backprop_data = out_backprop.template flat<T>().data();
+ auto* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>> MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>> ConstMatrixMap;
+
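+    // C accumulates the filter gradient [filter_total_size, out_depth]; each
+    // image contributes A^T * B, where A holds the im2col patches
+    // [output_image_size, filter_total_size] and B is that image's
+    // out_backprop [output_image_size, out_depth].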
+ MatrixMap C(filter_backprop_data, filter_total_size, out_depth);
+
+ C.setZero();
+ for (int image_id = 0; image_id < batch; ++image_id) {
+ // When we compute the gradient with respect to the filters, we need to do
+ // im2col to allow gemm-type computation.
+ Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
+ filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
+ stride, col_buffer_data);
+
+ ConstMatrixMap A(col_buffer_data, output_image_size, filter_total_size);
+ ConstMatrixMap B(out_backprop_data + output_offset * image_id,
+ output_image_size, out_depth);
+
+ // Compute gradient with respect to filter.
+ // TODO(andydavis) Use a multi-threaded matmul implementation here.
+ C.noalias() += A.transpose() * B;
+
+ input_data += input_offset;
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ Conv2DCustomBackpropFilterOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
+ .Device(DEVICE_CPU)
+ .Label("custom")
+ .TypeConstraint<float>("T"),
+ Conv2DCustomBackpropFilterOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
+ .Device(DEVICE_CPU)
+ .Label("eigen_tensor")
+ .TypeConstraint<float>("T"),
+ Conv2DFastBackpropFilterOp<CPUDevice, float>);
+
+// GPU definitions of both ops.
+#if GOOGLE_CUDA
+namespace {
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+ uint64 size) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
+ size * sizeof(T));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+} // namespace
+
+// The slow version (but compiles for GPU)
+
+// Backprop for input.
+template <typename Device, class T>
+class Conv2DSlowBackpropInputOp : public OpKernel {
+ public:
+ explicit Conv2DSlowBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
+ use_cudnn_ &= CanUseCudnn();
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_sizes = context->input(0);
+ const Tensor& filter = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input_sizes.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+ input_sizes.dims()));
+ TensorShape input_shape = VectorToShape(input_sizes.vec<int32>());
+ const TensorShape& filter_shape = filter.shape();
+
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropInput");
+ Tensor* in_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ const int padding_rows =
+ (output_rows - 1) * stride + filter_rows - input_rows;
+ const int padding_cols =
+ (output_cols - 1) * stride + filter_cols - input_cols;
+
+    // TODO(keveman): cuDNN only supports equal padding on both sides, so we
+    // only call it when that is true. Remove this check when (if?) cuDNN
+    // starts supporting different padding.
+ bool padding_compatible =
+ (padding_rows % 2 == 0) && (padding_cols % 2 == 0);
+
+ auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ if (use_cudnn_ && padding_compatible) {
+ if (filter_rows == 1 && filter_cols == 1 && stride == 1) {
+ // 1x1 filter, so call cublas directly.
+ const uint64 m = batch * input_rows * input_cols;
+ const uint64 k = out_depth;
+ const uint64 n = in_depth;
+
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
+ in_backprop->template flat<T>().size());
+
+ auto transpose = perftools::gputools::blas::Transpose::kTranspose;
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+
+ bool blas_launch_status =
+ stream->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr,
+ k, a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=",
+ m, ", n=", n, ", k=", k));
+ }
+ return;
+ }
+
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(batch)
+ .set_height(input_rows)
+ .set_width(input_cols)
+ .set_feature_map_count(in_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(batch)
+ .set_height(output_rows)
+ .set_width(output_cols)
+ .set_feature_map_count(out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter_rows)
+ .set_input_filter_width(filter_cols)
+ .set_input_feature_map_count(in_depth)
+ .set_output_feature_map_count(out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(stride)
+ .set_horizontal_filter_stride(stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
+
+ // NOTE(keveman):
+ // cuDNN only supports the following layouts :
+ // Input : B x D x R x C
+ // Filter : OD x ID x R x C
+ // Whereas, we have
+ // Input : B x R x C x D
+ // Filter : R x C x ID x OD
+ // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
+ // The first TransformDepth performs
+ // (B x R x C x D) => (B x D x R x C).
+ // Since the tensor returned from cuDNN is B x D x R x C also,
+ // the second TransformDepth performs
+ // (B x D x R x C) => (B x R x C x D).
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({out_depth, in_depth, filter_rows, filter_cols}),
+ &transformed_filter));
+
+ functor::TransformFilter<Device, T>()(context->eigen_device<Device>(),
+ filter.tensor<T, 4>(),
+ transformed_filter.tensor<T, 4>());
+
+ Tensor transformed_out_backprop;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({batch, out_depth, output_rows, output_cols}),
+ &transformed_out_backprop));
+
+ functor::TransformDepth<Device, T>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
+ Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2),
+ transformed_out_backprop.tensor<T, 4>());
+
+ Tensor pre_transformed_in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({batch, in_depth, input_rows, input_cols}),
+ &pre_transformed_in_backprop));
+
+ auto out_backprop_ptr =
+ AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
+ transformed_out_backprop.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto in_backprop_ptr =
+ AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
+ pre_transformed_in_backprop.template flat<T>().size());
+
+ bool cudnn_launch_status =
+ stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc,
+ out_backprop_ptr, conv_desc,
+ input_desc, &in_backprop_ptr)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ context->SetStatus(errors::Internal(
+ "cuDNN Backward Data function launch failure : input shape(",
+ input_shape.DebugString(), ") filter shape(",
+ filter_shape.DebugString(), ")"));
+ }
+
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::TransformDepth<Device, T>()(
+ context->eigen_device<Device>(),
+ toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
+ Eigen::DSizes<Eigen::DenseIndex, 4>(0, 2, 3, 1),
+ in_backprop->tensor<T, 4>());
+ } else {
+ // We fill out a padded out_backprop
+ TensorShape padded_out_shape(
+ {batch, padded_out_rows, padded_out_cols, out_depth});
+ Tensor padded_output;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ padded_out_shape, &padded_output));
+
+ Eigen::DSizes<Eigen::DenseIndex, 4> trivial_order{0, 1, 2, 3};
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{
+ {{0, 0},
+ {top_pad_rows, bottom_pad_rows},
+ {left_pad_cols, right_pad_cols},
+ {0, 0}}};
+
+ functor::InflatePadAndShuffle<Device, T, 4>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
+ pad_dims, trivial_order, padded_output.tensor<T, 4>());
+ const Tensor& padded_output_cref = padded_output;
+
+      // We then need to build a new "reversed" filter: transpose the
+      // in_depth and out_depth dimensions of the filter and reverse its
+      // rows and cols.
+ TensorShape r_filter_shape(
+ {filter_rows, filter_cols, out_depth, in_depth});
+ Tensor r_filter;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ r_filter_shape, &r_filter));
+
+ Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{0, 1, 3, 2};
+ Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
+ functor::ShuffleAndReverse<Device, T, 4>()(
+ context->eigen_device<Device>(), filter.tensor<T, 4>(), filter_order,
+ filter_rev_dims, r_filter.tensor<T, 4>());
+ const Tensor& r_filter_cref = r_filter;
+
+ // Now we can call conv_2d directly.
+ functor::SpatialConvolution<Device, T>()(
+ context->eigen_device<Device>(), in_backprop->tensor<T, 4>(),
+ padded_output_cref.tensor<T, 4>(), r_filter_cref.tensor<T, 4>(), 1,
+ BrainPadding2EigenPadding(VALID));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ bool use_cudnn_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
+};
+
+// Backprop for filter.
+template <typename Device, class T>
+class Conv2DSlowBackpropFilterOp : public OpKernel {
+ public:
+ explicit Conv2DSlowBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
+ use_cudnn_ &= CanUseCudnn();
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(filter_sizes.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
+ filter_sizes.dims()));
+ const TensorShape& input_shape = input.shape();
+ TensorShape filter_shape = VectorToShape(filter_sizes.vec<int32>());
+
+ EXTRACT_AND_VERIFY_DIMENSIONS("Conv2DBackpropFilter");
+ Tensor* filter_backprop = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ const int padding_rows =
+ (output_rows - 1) * stride + filter_rows - input_rows;
+ const int padding_cols =
+ (output_cols - 1) * stride + filter_cols - input_cols;
+
+    // TODO(zhengxq): cuDNN only supports equal padding on both sides, so we
+    // only call it when that is true. Remove this check when (if?) cuDNN
+    // starts supporting different padding.
+ bool padding_compatible =
+ (padding_rows % 2 == 0) && (padding_cols % 2 == 0);
+
+ auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ if (use_cudnn_ && padding_compatible) {
+ if (filter_rows == 1 && filter_cols == 1 && stride == 1) {
+ const uint64 m = in_depth;
+ const uint64 k = batch * input_rows * input_cols;
+ const uint64 n = out_depth;
+
+ // The shape of output backprop is
+ // [batch, out_rows, out_cols, out_depth]
+ // From cublas's perspective, it is: n x k
+ auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
+ out_backprop.template flat<T>().size());
+
+ // The shape of input is
+ // [batch, in_rows, in_cols, in_depth],
+ // From cublas's perspective, it is: m x k
+ auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+
+        // The shape of the filter backprop from the conv_2d should be
+ // [1, 1, in_depth, out_depth]
+ // From cublas's perspective, it is: n x m
+ auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
+ filter_backprop->template flat<T>().size());
+
+ bool blas_launch_status =
+ stream->ThenBlasGemm(
+ perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose, n, m, k,
+ 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=",
+ m, ", n=", n, ", k=", k));
+ }
+ return;
+ }
+
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(batch)
+ .set_height(input_rows)
+ .set_width(input_cols)
+ .set_feature_map_count(in_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(batch)
+ .set_height(output_rows)
+ .set_width(output_cols)
+ .set_feature_map_count(out_depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter_rows)
+ .set_input_filter_width(filter_cols)
+ .set_input_feature_map_count(in_depth)
+ .set_output_feature_map_count(out_depth);
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(stride)
+ .set_horizontal_filter_stride(stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
+
+ // NOTE(zhengxq):
+ // cuDNN only supports the following layouts :
+ // Input : B x D x R x C
+ // Filter : OD x ID x R x C
+ // Whereas, we have
+ // Input : B x R x C x D
+ // Filter : R x C x ID x OD
+ // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
+ // The first TransformDepth performs
+ // (B x R x C x D) => (B x D x R x C).
+ // Since the tensor returned from cuDNN is B x D x R x C also,
+ // the second TransformDepth performs
+ // (B x D x R x C) => (B x R x C x D).
+
+ Tensor pre_transformed_filter_backprop;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({out_depth, in_depth, filter_rows, filter_cols}),
+ &pre_transformed_filter_backprop));
+
+ Tensor transformed_out_backprop;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({batch, out_depth, output_rows, output_cols}),
+ &transformed_out_backprop));
+
+ functor::TransformDepth<Device, T>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
+ Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2),
+ transformed_out_backprop.tensor<T, 4>());
+
+ Tensor transformed_input;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({batch, in_depth, input_rows, input_cols}),
+ &transformed_input));
+
+ functor::TransformDepth<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 4>(),
+ Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2),
+ transformed_input.tensor<T, 4>());
+
+ auto out_backprop_ptr =
+ AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
+ transformed_out_backprop.template flat<T>().size());
+ auto filter_backprop_ptr = AsDeviceMemory(
+ pre_transformed_filter_backprop.template flat<T>().data(),
+ pre_transformed_filter_backprop.template flat<T>().size());
+ auto input_ptr =
+ AsDeviceMemory(transformed_input.template flat<T>().data(),
+ transformed_input.template flat<T>().size());
+
+ bool cudnn_launch_status =
+ stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc,
+ out_backprop_ptr, conv_desc,
+ filter_desc, &filter_backprop_ptr)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ context->SetStatus(errors::Internal(
+ "cuDNN Backward Filter function launch failure : input shape(",
+ input_shape.DebugString(), ") filter shape(",
+ filter_shape.DebugString(), ")"));
+ }
+
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ functor::TransformDepth<Device, T>()(
+ context->eigen_device<Device>(),
+ toConstTensor(pre_transformed_filter_backprop)
+ .template tensor<T, 4>(),
+ Eigen::DSizes<Eigen::DenseIndex, 4>(2, 3, 1, 0),
+ filter_backprop->tensor<T, 4>());
+ } else {
+ // Fall back to the non-cudnn code path
+
+ // For the backprop of the filter, we need to also transpose the
+ // out_backprop.
+ // The shape of backprop is
+ // [batch, out_rows, out_cols, out_depth]
+ // And we need to change it to
+ // [out_depth, out_rows, out_cols, batch]
+ Eigen::DSizes<Eigen::DenseIndex, 4> out_order{3, 1, 2, 0};
+ TensorShape padded_out_shape(
+ {out_depth, padded_out_rows, padded_out_cols, batch});
+ Tensor padded_output;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ padded_out_shape, &padded_output));
+
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4> pad_dims{
+ {{0, 0},
+ {top_pad_rows, bottom_pad_rows},
+ {left_pad_cols, right_pad_cols},
+ {0, 0}}};
+ functor::InflatePadAndShuffle<Device, T, 4>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(), strides,
+ pad_dims, out_order, padded_output.tensor<T, 4>());
+ const Tensor& padded_output_cref = padded_output;
+
+ // For the backprop of the filter, we need to transpose the input.
+ // The shape of input is
+ // [batch, in_rows, in_cols, in_depth]
+ // And we need to change it to
+ // [in_rows, in_cols, batch, in_depth]
+ Eigen::DSizes<Eigen::DenseIndex, 4> in_order{1, 2, 0, 3};
+ TensorShape in_shuffle_shape({input_rows, input_cols, batch, in_depth});
+ Tensor in_shuffle;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ in_shuffle_shape, &in_shuffle));
+
+ // No need for reversing this time.
+ Eigen::array<bool, 4> trivial_dims{false, false, false, false};
+ functor::ShuffleAndReverse<Device, T, 4>()(
+ context->eigen_device<Device>(), input.tensor<T, 4>(), in_order,
+ trivial_dims, in_shuffle.tensor<T, 4>());
+ const Tensor& in_shuffle_cref = in_shuffle;
+
+      // The output of the conv_2d would be
+      // [out_depth, filter_rows, filter_cols, in_depth]
+      // and we need to shuffle it back to
+      // [filter_rows, filter_cols, in_depth, out_depth],
+      // and also reverse the filter backprops.
+      // So we need to allocate (sigh) yet another piece of memory to hold
+      // the output.
+ TensorShape filter_shuffle_shape(
+ {out_depth, filter_rows, filter_cols, in_depth});
+ Tensor filter_shuffle;
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
+ filter_shuffle_shape,
+ &filter_shuffle));
+
+ functor::SpatialConvolution<Device, T>()(
+ context->eigen_device<Device>(), filter_shuffle.tensor<T, 4>(),
+ padded_output_cref.tensor<T, 4>(), in_shuffle_cref.tensor<T, 4>(), 1,
+ BrainPadding2EigenPadding(VALID));
+
+ // Now copy the filter_backprop back to the destination.
+ Eigen::DSizes<Eigen::DenseIndex, 4> filter_order{1, 2, 3, 0};
+ Eigen::array<bool, 4> filter_rev_dims{true, true, false, false};
+ const Tensor& filter_shuffle_cref = filter_shuffle;
+ functor::ShuffleAndReverse<Device, T, 4>()(
+ context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 4>(),
+ filter_order, filter_rev_dims, filter_backprop->tensor<T, 4>());
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ bool use_cudnn_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
+};
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ShuffleAndReverse<GPUDevice, T, 4>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
+ const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
+ const Eigen::array<bool, 4>& reverse_dims, \
+ typename TTypes<T, 4>::Tensor output); \
+ extern template struct ShuffleAndReverse<GPUDevice, T, 4>; \
+ template <> \
+ void InflatePadAndShuffle<GPUDevice, T, 4>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input, \
+ const Eigen::DSizes<Eigen::DenseIndex, 4>& strides, \
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 4>& pad_dims, \
+ const Eigen::DSizes<Eigen::DenseIndex, 4>& order, \
+ typename TTypes<T, 4>::Tensor output); \
+ extern template struct InflatePadAndShuffle<GPUDevice, T, 4>; \
+ template <> \
+ void TransformFilter<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
+ typename TTypes<T, 4>::Tensor out); \
+ extern template struct TransformFilter<GPUDevice, T>; \
+ template <> \
+ void TransformDepth<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
+ const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle, \
+ typename TTypes<T, 4>::Tensor out); \
+ extern template struct TransformDepth<GPUDevice, T>; \
+ template <> \
+ void SpatialConvolution<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
+ typename TTypes<T, 4>::ConstTensor input, \
+ typename TTypes<T, 4>::ConstTensor filter, int stride, \
+ const Eigen::PaddingType& padding); \
+ extern template struct SpatialConvolution<GPUDevice, T>; \
+ template <> \
+ void SpatialConvolutionBackwardInput<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::Tensor in_backprop, \
+ typename TTypes<T, 4>::ConstTensor filter, \
+ typename TTypes<T, 4>::ConstTensor output_backprop, int input_rows, \
+ int input_cols, int stride); \
+ extern template struct SpatialConvolutionBackwardInput<GPUDevice, T>
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("input_sizes"),
+ Conv2DSlowBackpropInputOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("filter_sizes"),
+ Conv2DSlowBackpropFilterOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
new file mode 100644
index 0000000000..aaa2951778
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -0,0 +1,373 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+struct LaunchGeneric {
+ static void launch(OpKernelContext* ctx, const Tensor& input,
+ const Tensor& filter, int stride,
+ const Eigen::PaddingType& padding, Tensor* output) {
+ if (filter.dim_size(1) == filter.dim_size(0) && filter.dim_size(0) == 1 &&
+ stride == 1) {
+ // For 1x1 kernel, the 2D convolution is reduced to matrix
+ // multiplication.
+ //
+ // TODO(vrv): We should be able to call SpatialConvolution
+ // and it will produce the same result, but doing so
+ // led to NaNs during training. Using matmul instead for now.
+ int conv_width = 1; // Width for the convolution step.
+ for (int i = 0; i < 3; ++i) {
+ conv_width *= output->dim_size(i);
+ }
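+      // For example (illustrative shapes): an output of [32, 8, 8, 192]
+      // gives conv_width = 32 * 8 * 8 = 2048, so the matmul below is
+      // [2048 x in_depth] * [in_depth x 192].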
+
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
+ functor::MatMulConvFunctor<Device, T>()(
+ ctx->eigen_device<Device>(),
+ output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
+ input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
+ filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
+ dim_pair);
+ } else {
+ functor::SpatialConvolution<Device, T>()(
+ ctx->eigen_device<Device>(), output->tensor<T, 4>(),
+ input.tensor<T, 4>(), filter.tensor<T, 4>(), stride, padding);
+ }
+ }
+};
+
+template <typename Device, typename T>
+struct LaunchConvOp;
+
+template <typename T>
+struct LaunchConvOp<CPUDevice, T> {
+ static void launch(OpKernelContext* ctx, bool use_cudnn, const Tensor& input,
+ const Tensor& filter, int stride,
+ const Eigen::PaddingType& padding, Tensor* output) {
+ LaunchGeneric<CPUDevice, T>::launch(ctx, input, filter, stride, padding,
+ output);
+ }
+};
+
+template <typename Device, typename T>
+class Conv2DOp : public BinaryOp<T> {
+ public:
+ explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
+ use_cudnn_ &= CanUseCudnn();
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(context, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_rows, in_cols, in_depth ]
+
+ const Tensor& input = context->input(0);
+
+ // Input filter is of the following dimensions:
+ // [ filter_rows, filter_cols, in_depth, out_depth]
+ const Tensor& filter = context->input(1);
+
+ // For 2D convolution, there should be 4 dimensions.
+ OP_REQUIRES(context, input.dims() == 4,
+                errors::InvalidArgument("input must be 4-dimensional: ",
+ input.shape().ShortDebugString()));
+ OP_REQUIRES(context, filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().ShortDebugString()));
+
+ // The last dimension for input is in_depth. It must be the same as the
+ // filter's in_depth.
+ const int64 in_depth = input.dim_size(3);
+ OP_REQUIRES(
+ context, in_depth == filter.dim_size(2),
+ errors::InvalidArgument("input and filter must have the same depth: ",
+ in_depth, " vs ", filter.dim_size(2)));
+
+ // The last dimension for filter is out_depth.
+ const int64 out_depth = filter.dim_size(3);
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows = input.dim_size(1);
+ const int64 filter_rows = filter.dim_size(0);
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols = input.dim_size(2);
+ const int64 filter_cols = filter.dim_size(1);
+
+ // The first dimension for input is batch.
+ const int64 batch = input.dim_size(0);
+
+ // For now we take the stride from the second dimension only (we
+ // assume row = col stride, and do not support striding on the
+ // batch or depth dimension).
+ const int stride = strides_[1];
+
+ int out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ if (filter_cols == filter_rows && filter_rows == 1 && stride == 1) {
+ // For 1x1 kernel, the 2D convolution is reduced to matrix
+ // multiplication.
+ out_rows = input_rows;
+ out_cols = input_cols;
+ } else {
+ OP_REQUIRES_OK(
+ context, Get2dOutputSize(input_rows, input_cols, filter_rows,
+ filter_cols, stride, stride, padding_,
+ &out_rows, &out_cols, &pad_rows, &pad_cols));
+ }
+ TensorShape out_shape({batch, out_rows, out_cols, out_depth});
+
+ // Output tensor is of the following dimensions:
+ // [ in_batch, out_rows, out_cols, out_depth ]
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ VLOG(2) << "Conv2D: in_depth = " << in_depth
+ << ", input_cols = " << input_cols
+ << ", filter_cols = " << filter_cols
+ << ", input_rows = " << input_rows
+ << ", filter_rows = " << filter_rows << ", stride = " << stride
+ << ", out_depth = " << out_depth;
+
+ LaunchConvOp<Device, T>::launch(context, use_cudnn_, input, filter, stride,
+ BrainPadding2EigenPadding(padding_),
+ output);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ bool use_cudnn_;
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Conv2D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ Conv2DOp<CPUDevice, float>);
+
+#if GOOGLE_CUDA
+
+namespace {
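+// Wraps an existing device pointer holding `size` elements of T in a typed
+// perftools::gputools::DeviceMemory<T> handle so it can be passed to
+// StreamExecutor BLAS/DNN calls. The wrapper does not own the memory.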
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+ uint64 size) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
+ size * sizeof(T));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+} // namespace
+
+template <typename T>
+struct LaunchConvOp<GPUDevice, T> {
+ static void launch(OpKernelContext* ctx, bool use_cudnn,
+ const Tensor& input_param, const Tensor& filter,
+ int stride, const Eigen::PaddingType& padding,
+ Tensor* output) {
+ auto* stream = ctx->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ if (use_cudnn) {
+ Tensor input = input_param;
+ if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1) {
+ // 1x1 filter, so call cublas directly.
+ const uint64 m =
+ input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
+ const uint64 k = filter.dim_size(2);
+ const uint64 n = filter.dim_size(3);
+
+ auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
+ output->template flat<T>().size());
+
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
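+        // The tensors are stored row-major, so output = input * filter is
+        // computed as a column-major GEMM on the transposed operands:
+        // output^T (n x m) = filter^T (n x k) * input^T (k x m).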
+ bool blas_launch_status =
+ stream->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f,
+ b_ptr, n, a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
+ }
+ return;
+ }
+ if (padding == Eigen::PADDING_SAME) {
+ const int64 out_rows = output->dim_size(1);
+ const int64 out_cols = output->dim_size(2);
+ const int64 in_rows = input.dim_size(1);
+ const int64 in_cols = input.dim_size(2);
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
+ // Total padding on rows and cols is
+ // Pr = (R' - 1) * S + Kr - R
+ // Pc = (C' - 1) * S + Kc - C
+ // where (R', C') are output dimensions, (R, C) are input dimensions, S
+ // is stride, (Kr, Kc) are filter dimensions.
+        // We pad Pr/2 on the top and Pr - Pr/2 on the bottom, Pc/2 on the
+        // left and Pc - Pc/2 on the right. When Pr or Pc is odd, this means
+ // we pad more on the right and bottom than on the top and left.
+ const int padding_rows = (out_rows - 1) * stride + patch_rows - in_rows;
+ const int padding_cols = (out_cols - 1) * stride + patch_cols - in_cols;
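+        // For example (illustrative sizes): with in_rows = 6, patch_rows = 3
+        // and stride = 2, out_rows = 3, so padding_rows = (3 - 1) * 2 + 3 - 6
+        // = 1; we pad 0 rows before and 1 row after.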
+ Tensor transformed_input;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape(
+ {input.dim_size(0), input.dim_size(1) + padding_rows,
+ input.dim_size(2) + padding_cols, input.dim_size(3)}),
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T>()(
+ ctx->eigen_device<GPUDevice>(), input_param.tensor<T, 4>(),
+ padding_rows / 2, padding_rows - padding_rows / 2, padding_cols / 2,
+ padding_cols - padding_cols / 2, transformed_input.tensor<T, 4>());
+ input = transformed_input;
+ }
+
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(input.dim_size(0))
+ .set_height(input.dim_size(1))
+ .set_width(input.dim_size(2))
+ .set_feature_map_count(input.dim_size(3))
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(output->dim_size(0))
+ .set_height(output->dim_size(1))
+ .set_width(output->dim_size(2))
+ .set_feature_map_count(output->dim_size(3))
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter.dim_size(0))
+ .set_input_filter_width(filter.dim_size(1))
+ .set_input_feature_map_count(filter.dim_size(2))
+ .set_output_feature_map_count(filter.dim_size(3));
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(stride)
+ .set_horizontal_filter_stride(stride);
+
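+      // cuDNN expects the filter in [out_depth, in_depth, filter_rows,
+      // filter_cols] order, so shuffle it from TensorFlow's
+      // [filter_rows, filter_cols, in_depth, out_depth] layout into a
+      // temporary tensor.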
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({filter.dim_size(3), filter.dim_size(2),
+ filter.dim_size(0), filter.dim_size(1)}),
+ &transformed_filter));
+
+ functor::TransformFilter<GPUDevice, T>()(
+ ctx->eigen_device<GPUDevice>(), filter.tensor<T, 4>(),
+ transformed_filter.tensor<T, 4>());
+
+ auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto output_ptr = AsDeviceMemory(output->template flat<T>().data(),
+ output->template flat<T>().size());
+
+ bool cudnn_launch_status =
+ stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr,
+ conv_desc, output_desc, &output_ptr)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "cuDNN launch failure : input shape(", input.shape().DebugString(),
+ ") filter shape(", filter.shape().DebugString(), ")"));
+ }
+ } else {
+ LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
+ padding, output);
+ }
+ }
+};
+
+#endif // GOOGLE_CUDA
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void SpatialConvolution<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::Tensor output, \
+ typename TTypes<T, 4>::ConstTensor input, \
+ typename TTypes<T, 4>::ConstTensor filter, int stride, \
+ const Eigen::PaddingType& padding); \
+ extern template struct SpatialConvolution<GPUDevice, T>; \
+ template <> \
+ void MatMulConvFunctor<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 2>::Tensor out, \
+ typename TTypes<T, 2>::ConstTensor in0, \
+ typename TTypes<T, 2>::ConstTensor in1, \
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair); \
+ extern template struct MatMulConvFunctor<GPUDevice, T>; \
+ template <> \
+ void TransformFilter<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
+ typename TTypes<T, 4>::Tensor out); \
+ extern template struct TransformFilter<GPUDevice, T>; \
+ template <> \
+ void PadInput<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
+ int padding_rows_left, int padding_rows_right, int padding_cols_left, \
+ int padding_cols_right, typename TTypes<T, 4>::Tensor out); \
+ extern template struct PadInput<GPUDevice, T>
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+REGISTER_KERNEL_BUILDER(Name("Conv2D")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ Conv2DOp<GPUDevice, float>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops_gpu.cu.cc b/tensorflow/core/kernels/conv_ops_gpu.cu.cc
new file mode 100644
index 0000000000..44af814e2b
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_gpu.cu.cc
@@ -0,0 +1,35 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/conv_2d.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+template <typename T>
+struct SpatialConvolution<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T, 4>::Tensor output,
+ typename TTypes<T, 4>::ConstTensor input,
+ typename TTypes<T, 4>::ConstTensor filter, int stride,
+ const Eigen::PaddingType& padding) {
+ // TODO(keveman): nvcc 6.5 crashes when 32 bit indexing is turned on. Enable
+ // this when we move to cuda 7.0.
+ // SpatialConvolutionFunc(d, To32Bit(output), To32Bit(input),
+ // To32Bit(filter), stride, padding);
+
+ SpatialConvolutionFunc(d, output, input, filter, stride, padding);
+ }
+};
+
+template struct SpatialConvolution<GPUDevice, float>;
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc
new file mode 100644
index 0000000000..e2e9d25d83
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc
@@ -0,0 +1,16 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/conv_2d.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::InflatePadAndShuffle<GPUDevice, float, 4>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
new file mode 100644
index 0000000000..dbbe08ef9c
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -0,0 +1,22 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/conv_2d.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::ShuffleAndReverse<GPUDevice, float, 4>;
+
+template struct functor::TransformFilter<GPUDevice, float>;
+
+template struct functor::PadInput<GPUDevice, float>;
+
+template struct functor::TransformDepth<GPUDevice, float>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc
new file mode 100644
index 0000000000..87d79ecb4d
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_gpu_matmul.cu.cc
@@ -0,0 +1,16 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/conv_2d.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::MatMulConvFunctor<GPUDevice, float>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/core_ops_test.cc b/tensorflow/core/kernels/core_ops_test.cc
new file mode 100644
index 0000000000..a42a5999da
--- /dev/null
+++ b/tensorflow/core/kernels/core_ops_test.cc
@@ -0,0 +1,990 @@
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+static void SetConstOp(const string& name, std::initializer_list<int64> dims,
+ NodeDef* node) {
+ Tensor tensor(DT_FLOAT, TensorShape(dims));
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ tensor.flat<float>()(i) = i / 10.0f;
+ }
+ TF_CHECK_OK(NodeDefBuilder(name, "Const")
+ .Attr("dtype", DT_FLOAT)
+ .Attr("value", tensor)
+ .Finalize(node));
+}
+
+static void SetConstSizesOp(const string& name, const std::vector<int32>& sizes,
+ NodeDef* node) {
+ TensorShape shape;
+ shape.AddDim(sizes.size());
+ Tensor tensor(DT_INT32, shape);
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ tensor.flat<int32>()(i) = sizes[i];
+ }
+ TF_CHECK_OK(NodeDefBuilder(name, "Const")
+ .Attr("dtype", DT_INT32)
+ .Attr("value", tensor)
+ .Finalize(node));
+}
+
+namespace {
+
+enum CONV_OP {
+ CONV_OP_FORWARD = 0,
+ CONV_OP_BACKPROP_INPUT = 1,
+ CONV_OP_BACKPROP_FILTER = 2
+};
+
+} // namespace
+
+static void BM_ConvFloat(int iters, int batch, int rows, int cols, int in_depth,
+ int out_depth, int filter_rows, int filter_cols,
+ CONV_OP op, int num_threads, int stride,
+ Padding padding, bool use_gpu, const string& label) {
+ if (!IsGoogleCudaEnabled() && use_gpu) {
+ testing::SetLabel(
+ strings::StrCat("Skipping GPU test (no --config=cuda): ", label));
+ return;
+ }
+ testing::SetLabel(label);
+
+ // Set the number of threads
+ SessionOptions options;
+ options.config.set_intra_op_parallelism_threads(num_threads);
+
+ // We set up a graph for computing convolution.
+ GraphDef graph;
+
+ // For this, we need an input tensor and a filter tensor.
+ // Compute the output size.
+ int out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ TF_CHECK_OK(Get2dOutputSize(rows, cols, filter_rows, filter_cols, stride,
+ stride, padding, &out_rows, &out_cols, &pad_rows,
+ &pad_cols));
+ // Counting the number of floating point operations (both MUL and ADD)
+ int64 num_ops = 0;
+ if (op == CONV_OP_FORWARD) {
+ // Forward computation:
+    // BATCH x OUT_ROW X OUT_COL X IN_DEPTH X PATCH_ROW X PATCH_COL X OUT_DEPTH
+    // We multiply by two since there are multiplications and additions.
+ num_ops = static_cast<int64>(batch * in_depth * out_depth) *
+ static_cast<int64>(filter_rows * filter_cols) *
+ static_cast<int64>(out_rows * out_cols) * 2;
+ } else {
+ // Backward computation: both input and filter backprop take the same
+ // amount of computation:
+ // BATCH x IN_ROW X IN_COL X IN_DEPTH X PATCH_ROW X PATCH_COL X OUT_DEPTH
+    // We multiply by two since there are multiplications and additions.
+ num_ops = static_cast<int64>(batch * in_depth * out_depth) *
+ static_cast<int64>(filter_rows * filter_cols) *
+ static_cast<int64>(rows * cols) * 2;
+ }
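+  // As an illustration of the formula above (hypothetical sizes): a forward
+  // pass with batch = 1, in_depth = 2, out_depth = 4, a 3x3 filter and a 5x5
+  // output counts 1 * 2 * 4 * (3 * 3) * (5 * 5) * 2 = 3600 operations.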
+
+ SetConstOp("input", {batch, rows, cols, in_depth}, graph.add_node());
+ SetConstOp("filter", {filter_rows, filter_cols, in_depth, out_depth},
+ graph.add_node());
+ SetConstOp("output_backprop", {batch, out_rows, out_cols, out_depth},
+ graph.add_node());
+ SetConstSizesOp("input_sizes",
+ std::vector<int32>({batch, rows, cols, in_depth}),
+ graph.add_node());
+ SetConstSizesOp("filter_sizes", std::vector<int32>({filter_rows, filter_cols,
+ in_depth, out_depth}),
+ graph.add_node());
+
+ // Now add the convolution op
+ NodeDef* conv = graph.add_node();
+ switch (op) {
+ case CONV_OP_FORWARD:
+ TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2D")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter", 0, DT_FLOAT)
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", padding == VALID ? "VALID" : "SAME")
+ .Finalize(conv));
+ break;
+ case CONV_OP_BACKPROP_INPUT:
+ TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2DBackpropInput")
+ .Input("input_sizes", 0, DT_INT32)
+ .Input("filter", 0, DT_FLOAT)
+ .Input("output_backprop", 0, DT_FLOAT)
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", padding == VALID ? "VALID" : "SAME")
+ .Finalize(conv));
+ break;
+ case CONV_OP_BACKPROP_FILTER:
+ TF_CHECK_OK(NodeDefBuilder("conv2d", "Conv2DBackpropFilter")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter_sizes", 0, DT_INT32)
+ .Input("output_backprop", 0, DT_FLOAT)
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", padding == VALID ? "VALID" : "SAME")
+ .Finalize(conv));
+ break;
+ }
+ Graph* g = new Graph(OpRegistry::Global());
+ GraphConstructorOptions opts;
+ TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph, g));
+
+ string device = use_gpu ? "gpu" : "cpu";
+ test::Benchmark(device, g, &options).Run(iters);
+ testing::ItemsProcessed(num_ops * iters);
+}
+
+// BS: batch_size
+// R: tensor_in_rows
+// C: tensor_in_cols
+// ID: input_depth
+// OD: output_depth
+// KR: kernel_rows
+// KC: kernel_cols
+#define BM_ConvFloatFwd(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \
+ static void BM_ConvFloatFwdCPU1_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
+ PAD, false, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
+ } \
+ static void BM_ConvFloatFwdCPU4_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 4, STR, \
+ PAD, false, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
+ } \
+ static void BM_ConvFloatFwdGPU_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_FORWARD, 1, STR, \
+ PAD, true, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
+ } \
+ BENCHMARK(BM_ConvFloatFwdCPU1_##LABEL); \
+ BENCHMARK(BM_ConvFloatFwdCPU4_##LABEL); \
+ BENCHMARK(BM_ConvFloatFwdGPU_##LABEL)
+
+BM_ConvFloatFwd(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0);
+BM_ConvFloatFwd(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1);
+BM_ConvFloatFwd(32, 8, 8, 384, 384, 3, 1, 1, SAME, conv2);
+BM_ConvFloatFwd(32, 8, 8, 2048, 192, 1, 1, 1, SAME, conv3);
+BM_ConvFloatFwd(32, 8, 8, 448, 384, 3, 3, 1, SAME, conv4);
+BM_ConvFloatFwd(32, 8, 8, 2048, 320, 1, 1, 1, SAME, conv5);
+BM_ConvFloatFwd(32, 8, 8, 2048, 448, 1, 1, 1, SAME, conv6);
+BM_ConvFloatFwd(32, 8, 8, 2048, 384, 1, 1, 1, SAME, conv7);
+BM_ConvFloatFwd(32, 8, 8, 1760, 384, 1, 1, 1, SAME, conv8);
+BM_ConvFloatFwd(32, 8, 8, 1760, 192, 1, 1, 1, SAME, conv9);
+BM_ConvFloatFwd(32, 8, 8, 1760, 448, 1, 1, 1, SAME, conv10);
+BM_ConvFloatFwd(32, 8, 8, 1760, 320, 1, 1, 1, SAME, conv11);
+BM_ConvFloatFwd(32, 17, 17, 192, 192, 3, 3, 2, VALID, conv12);
+BM_ConvFloatFwd(32, 17, 17, 192, 192, 3, 3, 1, SAME, conv13);
+BM_ConvFloatFwd(32, 17, 17, 1248, 192, 1, 1, 1, SAME, conv14);
+BM_ConvFloatFwd(32, 17, 17, 128, 320, 3, 3, 2, VALID, conv15);
+BM_ConvFloatFwd(32, 17, 17, 1248, 128, 1, 1, 1, SAME, conv16);
+BM_ConvFloatFwd(32, 17, 17, 224, 224, 1, 3, 1, SAME, conv17);
+BM_ConvFloatFwd(32, 17, 17, 192, 256, 3, 1, 1, SAME, conv18);
+BM_ConvFloatFwd(32, 17, 17, 192, 256, 1, 3, 1, SAME, conv19);
+BM_ConvFloatFwd(32, 17, 17, 1216, 192, 1, 1, 1, SAME, conv20);
+BM_ConvFloatFwd(32, 17, 17, 1216, 96, 1, 1, 1, SAME, conv21);
+BM_ConvFloatFwd(32, 17, 17, 224, 224, 3, 1, 1, SAME, conv22);
+BM_ConvFloatFwd(32, 17, 17, 192, 224, 3, 3, 1, SAME, conv23);
+BM_ConvFloatFwd(32, 17, 17, 192, 192, 1, 3, 1, SAME, conv24);
+BM_ConvFloatFwd(32, 17, 17, 1152, 192, 1, 1, 1, SAME, conv25);
+BM_ConvFloatFwd(32, 17, 17, 1152, 128, 1, 1, 1, SAME, conv26);
+BM_ConvFloatFwd(32, 17, 17, 192, 192, 3, 1, 1, SAME, conv27);
+BM_ConvFloatFwd(32, 17, 17, 160, 192, 3, 3, 1, SAME, conv28);
+BM_ConvFloatFwd(32, 17, 17, 1152, 160, 1, 1, 1, SAME, conv29);
+BM_ConvFloatFwd(32, 17, 17, 1024, 128, 1, 1, 1, SAME, conv30);
+BM_ConvFloatFwd(32, 17, 17, 128, 192, 1, 3, 1, SAME, conv31);
+BM_ConvFloatFwd(32, 17, 17, 1024, 160, 1, 1, 1, SAME, conv32);
+BM_ConvFloatFwd(32, 17, 17, 128, 192, 3, 1, 1, SAME, conv33);
+BM_ConvFloatFwd(32, 17, 17, 1024, 256, 1, 1, 1, SAME, conv34);
+BM_ConvFloatFwd(32, 17, 17, 128, 128, 3, 1, 1, SAME, conv35);
+BM_ConvFloatFwd(32, 17, 17, 768, 192, 1, 1, 1, SAME, conv36);
+BM_ConvFloatFwd(32, 17, 17, 128, 128, 1, 3, 1, SAME, conv37);
+BM_ConvFloatFwd(32, 17, 17, 128, 128, 3, 3, 1, SAME, conv38);
+BM_ConvFloatFwd(32, 17, 17, 768, 128, 1, 1, 1, SAME, conv39);
+BM_ConvFloatFwd(32, 17, 17, 768, 320, 1, 1, 1, SAME, conv40);
+BM_ConvFloatFwd(32, 35, 35, 96, 96, 3, 3, 2, VALID, conv41);
+BM_ConvFloatFwd(32, 35, 35, 288, 384, 3, 3, 2, VALID, conv42);
+BM_ConvFloatFwd(32, 35, 35, 64, 96, 3, 3, 1, SAME, conv43);
+BM_ConvFloatFwd(32, 35, 35, 288, 64, 1, 1, 1, SAME, conv44);
+BM_ConvFloatFwd(32, 35, 35, 256, 64, 1, 1, 1, SAME, conv45);
+BM_ConvFloatFwd(32, 35, 35, 48, 64, 5, 5, 1, SAME, conv46);
+BM_ConvFloatFwd(32, 35, 35, 256, 48, 1, 1, 1, SAME, conv47);
+BM_ConvFloatFwd(32, 35, 35, 96, 96, 3, 3, 1, SAME, conv48);
+BM_ConvFloatFwd(32, 35, 35, 192, 32, 1, 1, 1, SAME, conv49);
+BM_ConvFloatFwd(32, 35, 35, 192, 64, 1, 1, 1, SAME, conv50);
+BM_ConvFloatFwd(32, 35, 35, 192, 48, 1, 1, 1, SAME, conv51);
+BM_ConvFloatFwd(32, 73, 73, 64, 192, 3, 3, 1, VALID, conv52);
+BM_ConvFloatFwd(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53);
+BM_ConvFloatFwd(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
+
+#define BM_ConvFloatBkInAndFilter(BS, R, C, ID, OD, KR, KC, STR, PAD, LABEL) \
+ static void BM_ConvFloatBkInCPU1_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
+ STR, PAD, false, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
+ } \
+ static void BM_ConvFloatBkInCPU4_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 4, \
+ STR, PAD, false, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
+ } \
+ static void BM_ConvFloatBkInGPU_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_INPUT, 1, \
+ STR, PAD, true, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
+ } \
+ static void BM_ConvFloatBkFilterCPU1_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
+ STR, PAD, false, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_cpu1")); \
+ } \
+ static void BM_ConvFloatBkFilterCPU4_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 4, \
+ STR, PAD, false, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_cpu4")); \
+ } \
+ static void BM_ConvFloatBkFilterGPU_##LABEL(int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
+ STR, PAD, true, \
+ strings::StrCat(BS, "_", R, "_", C, "_", ID, "_", OD, "_", \
+ KR, "_", KC, "_", STR, "_", PAD, "_gpu")); \
+ } \
+ BENCHMARK(BM_ConvFloatBkInCPU1_##LABEL); \
+ BENCHMARK(BM_ConvFloatBkInCPU4_##LABEL); \
+ BENCHMARK(BM_ConvFloatBkInGPU_##LABEL); \
+ BENCHMARK(BM_ConvFloatBkFilterCPU1_##LABEL); \
+ BENCHMARK(BM_ConvFloatBkFilterCPU4_##LABEL); \
+ BENCHMARK(BM_ConvFloatBkFilterGPU_##LABEL)
+
+// Benchmarks from the inception model
+
+BM_ConvFloatBkInAndFilter(32, 5, 5, 1248, 128, 1, 1, 1, SAME, conv0);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 384, 384, 1, 3, 1, SAME, conv1);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 384, 384, 3, 1, 1, SAME, conv2);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 192, 1, 1, 1, SAME, conv3);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 448, 384, 3, 3, 1, SAME, conv4);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 320, 1, 1, 1, SAME, conv5);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 448, 1, 1, 1, SAME, conv6);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 2048, 384, 1, 1, 1, SAME, conv7);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 384, 1, 1, 1, SAME, conv8);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 192, 1, 1, 1, SAME, conv9);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 448, 1, 1, 1, SAME, conv10);
+BM_ConvFloatBkInAndFilter(32, 8, 8, 1760, 320, 1, 1, 1, SAME, conv11);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 3, 3, 2, VALID, conv12);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 3, 3, 1, SAME, conv13);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1248, 192, 1, 1, 1, SAME, conv14);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 320, 3, 3, 2, VALID, conv15);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1248, 128, 1, 1, 1, SAME, conv16);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 224, 224, 1, 3, 1, SAME, conv17);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 256, 3, 1, 1, SAME, conv18);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 256, 1, 3, 1, SAME, conv19);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1216, 192, 1, 1, 1, SAME, conv20);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1216, 96, 1, 1, 1, SAME, conv21);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 224, 224, 3, 1, 1, SAME, conv22);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 224, 3, 3, 1, SAME, conv23);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 1, 3, 1, SAME, conv24);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1152, 192, 1, 1, 1, SAME, conv25);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1152, 128, 1, 1, 1, SAME, conv26);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 192, 192, 3, 1, 1, SAME, conv27);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 160, 192, 3, 3, 1, SAME, conv28);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1152, 160, 1, 1, 1, SAME, conv29);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1024, 128, 1, 1, 1, SAME, conv30);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 192, 1, 3, 1, SAME, conv31);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1024, 160, 1, 1, 1, SAME, conv32);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 192, 3, 1, 1, SAME, conv33);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 1024, 256, 1, 1, 1, SAME, conv34);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 128, 3, 1, 1, SAME, conv35);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 768, 192, 1, 1, 1, SAME, conv36);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 128, 1, 3, 1, SAME, conv37);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 128, 128, 3, 3, 1, SAME, conv38);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 768, 128, 1, 1, 1, SAME, conv39);
+BM_ConvFloatBkInAndFilter(32, 17, 17, 768, 320, 1, 1, 1, SAME, conv40);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 96, 96, 3, 3, 2, VALID, conv41);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 288, 384, 3, 3, 2, VALID, conv42);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 64, 96, 3, 3, 1, SAME, conv43);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 288, 64, 1, 1, 1, SAME, conv44);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 256, 64, 1, 1, 1, SAME, conv45);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 48, 64, 5, 5, 1, SAME, conv46);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 256, 48, 1, 1, 1, SAME, conv47);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 96, 96, 3, 3, 1, SAME, conv48);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 192, 32, 1, 1, 1, SAME, conv49);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 192, 64, 1, 1, 1, SAME, conv50);
+BM_ConvFloatBkInAndFilter(32, 35, 35, 192, 48, 1, 1, 1, SAME, conv51);
+BM_ConvFloatBkInAndFilter(32, 73, 73, 64, 192, 3, 3, 1, VALID, conv52);
+BM_ConvFloatBkInAndFilter(32, 73, 73, 64, 64, 1, 1, 1, VALID, conv53);
+BM_ConvFloatBkInAndFilter(32, 147, 147, 24, 64, 1, 1, 1, VALID, conv54);
+
+#define BM_ConvFloatBkFCPU(BS, R, C, ID, OD, KR, KC, TH, LABEL) \
+ static void \
+ BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH( \
+ int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, TH, \
+ 1, VALID, false, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_ConvFloatBkFCPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC##_##TH)
+
+// Benchmarks from https://github.com/soumith/convnet-benchmarks
+BM_ConvFloatBkFCPU(128, 128, 128, 3, 96, 11, 11, 4, "convnet-layer1");
+BM_ConvFloatBkFCPU(128, 64, 64, 64, 128, 9, 9, 4, "convnet-layer2");
+BM_ConvFloatBkFCPU(128, 32, 32, 128, 128, 9, 9, 4, "convnet-layer3");
+BM_ConvFloatBkFCPU(128, 16, 16, 128, 128, 7, 7, 4, "convnet-layer4");
+BM_ConvFloatBkFCPU(128, 13, 13, 384, 384, 3, 3, 4, "convnet-layer5");
+
+#define BM_ConvFloatBkFGPU(BS, R, C, ID, OD, KR, KC, LABEL) \
+ static void BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC( \
+ int iters) { \
+ BM_ConvFloat(iters, BS, R, C, ID, OD, KR, KC, CONV_OP_BACKPROP_FILTER, 1, \
+ 1, VALID, true, LABEL); \
+ } \
+ BENCHMARK(BM_ConvFloatBkFGPU_##BS##_##R##_##C##_##ID##_##OD##_##KR##_##KC)
+
+// Benchmarks from https://github.com/soumith/convnet-benchmarks
+BM_ConvFloatBkFGPU(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
+BM_ConvFloatBkFGPU(128, 64, 64, 64, 128, 9, 9, "convnet-layer2");
+BM_ConvFloatBkFGPU(128, 32, 32, 128, 128, 9, 9, "convnet-layer3");
+BM_ConvFloatBkFGPU(128, 16, 16, 128, 128, 7, 7, "convnet-layer4");
+BM_ConvFloatBkFGPU(128, 13, 13, 384, 384, 3, 3, "convnet-layer5");
+
+static void BM_LRNFloat(int iters, int depth, int cols, int rows,
+ int batch_size, int range, int num_threads,
+ const string& label) {
+ tensorflow::testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
+ device->set_eigen_cpu_device(&eigen_cpu_device);
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TensorShape shape({batch_size, rows, cols, depth});
+
+ Tensor input(DT_FLOAT, shape);
+ test::FillIota<float>(&input, 1.0);
+ inputs.push_back({nullptr, &input});
+
+  // LRN op.
+ NodeDef lrn_node_def;
+ TF_CHECK_OK(NodeDefBuilder("lrn_op", "LRN")
+ .Input("input", 0, DT_FLOAT)
+ .Attr("depth_radius", range)
+ .Attr("bias", 1.0)
+ .Attr("alpha", 0.1)
+ .Attr("beta", 0.5)
+ .Finalize(&lrn_node_def));
+
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), lrn_node_def, &status));
+ TF_CHECK_OK(status);
+
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
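+  // Place an output in host memory when the kernel declares HOST_MEMORY for
+  // that output slot.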
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> context(new OpKernelContext(params));
+
+ op->Compute(context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete context->release_output(0).tensor;
+ op->Compute(context.get());
+ }
+ tensorflow::testing::StopTiming();
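+  // Each output element reads a depth window of (2 * range + 1) inputs, at
+  // roughly a multiply and an add per input, hence the factor below.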
+ testing::ItemsProcessed(context->mutable_output(0)->NumElements() * iters *
+ (2 * range + 1) * 2);
+ testing::SetLabel(label);
+}
+
+#define BM_LRNFloatFwdCPU(DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL) \
+ static void \
+ BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS( \
+ int iters) { \
+ BM_LRNFloat(iters, DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_LRNFloat_##DEPTH##_##COLS##_##ROWS##_##BATCH##_##RANGE##_##THREADS)
+
+// clang-format off
+// DEPTH, COLS, ROWS, BATCH, RANGE, THREADS, LABEL
+BM_LRNFloatFwdCPU(64, 56, 56, 32, 5, 1, "lrn 1 thread");
+BM_LRNFloatFwdCPU(192, 28, 28, 64, 2, 1, "lrn 1 thread");
+BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 1, "lrn 1 thread");
+BM_LRNFloatFwdCPU(64, 56, 56, 32, 5, 4, "lrn 4 threads");
+BM_LRNFloatFwdCPU(192, 28, 28, 64, 2, 4, "lrn 4 threads");
+BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 4, "lrn 4 threads");
+BM_LRNFloatFwdCPU(64, 56, 56, 32, 5, 8, "lrn 8 threads");
+BM_LRNFloatFwdCPU(192, 28, 28, 64, 2, 8, "lrn 8 threads");
+BM_LRNFloatFwdCPU(192, 56, 56, 32, 5, 8, "lrn 8 threads");
+// clang-format on
+
+/*
+AvgPooling Op
+*/
+static void BM_AvgPool(int iters, int batch_size, int rows, int cols, int depth,
+ int kernel_rows, int kernel_cols, int stride,
+ Padding padding, int num_threads, const string& label) {
+ tensorflow::testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
+ device->set_eigen_cpu_device(&eigen_cpu_device);
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TensorShape shape1({batch_size, rows, cols, depth});
+ Tensor input1(DT_FLOAT, shape1);
+ test::FillIota<float>(&input1, 1.0);
+ inputs.push_back({nullptr, &input1});
+
+ // AvgPooling op.
+ NodeDef avgpool_node_def;
+ CHECK_EQ(kernel_rows, kernel_cols);
+ Status status = NodeDefBuilder("avgpool_op", "AvgPool")
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", padding == VALID ? "VALID" : "SAME")
+ .Finalize(&avgpool_node_def);
+ TF_CHECK_OK(status);
+
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), avgpool_node_def, &status));
+ TF_CHECK_OK(status);
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> avgpool_context(new OpKernelContext(params));
+
+ op->Compute(avgpool_context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete avgpool_context->release_output(0).tensor;
+ op->Compute(avgpool_context.get());
+ }
+ tensorflow::testing::StopTiming();
+ testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
+ iters);
+ testing::SetLabel(label);
+}
+
+// BS: batch_size
+// IR: input_rows
+// IC: input_cols
+// ND: node_depth
+// KR: kernel_rows
+// KC: kernel_cols
+// ST: stride. We use the same stride for both directions.
+// PT: padding
+#define BM_AvgPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
+ static void \
+ BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
+ int iters) { \
+ BM_AvgPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_AvgPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
+
+// Labels are taken from the 2014-July-24 version of imagenet
+BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "avgpool0_VALID");
+BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 1, "avgpool1_VALID");
+BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 1, "avgpool4_VALID");
+BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 1, "avgpool10_VALID");
+BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 1, "avgpool0_SAME");
+BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 1, "avgpool1_SAME");
+BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 1, "avgpool4_SAME");
+BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 1, "avgpool10_SAME");
+BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 4, "avgpool0_VALID");
+BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 4, "avgpool1_VALID");
+BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 4, "avgpool4_VALID");
+BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 4, "avgpool10_VALID");
+BM_AvgPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 4, "avgpool0_SAME");
+BM_AvgPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "avgpool1_SAME");
+BM_AvgPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "avgpool4_SAME");
+BM_AvgPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "avgpool10_SAME");
+
+static void BM_AvgPoolBk(int iters, int batch_size, int rows, int cols,
+ int depth, int kernel_rows, int kernel_cols,
+ int stride, Padding padding, int num_threads,
+ const string& label) {
+ tensorflow::testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
+ device->set_eigen_cpu_device(&eigen_cpu_device);
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+
+ int out_height, out_width, pad_rows, pad_cols;
+ Status status =
+ Get2dOutputSize(rows, cols, kernel_rows, kernel_cols, stride, stride,
+ padding, &out_height, &out_width, &pad_rows, &pad_cols);
+ TF_CHECK_OK(status);
+ TensorShape output_shape({batch_size, out_height, out_width, depth});
+ TensorShape shape2({4});
+ Tensor input_shape_tensor(DT_INT32, shape2);
+ int32 input_dims[] = {batch_size, rows, cols, depth};
+ for (int i = 0; i < 4; i++) {
+ input_shape_tensor.flat<int32>()(i) = input_dims[i];
+ }
+ inputs.push_back({nullptr, &input_shape_tensor});
+
+ Tensor output_backprop(DT_FLOAT, output_shape);
+ test::FillIota<float>(&output_backprop, 11.0);
+ inputs.push_back({nullptr, &output_backprop});
+
+ // AvgPoolGrad op.
+ NodeDef avgpool_grad_node_def;
+ status = NodeDefBuilder("avgpool_grad_op", "AvgPoolGrad")
+ .Input(FakeInput())
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", padding == VALID ? "VALID" : "SAME")
+ .Finalize(&avgpool_grad_node_def);
+ TF_CHECK_OK(status);
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, nullptr, cpu_allocator(), avgpool_grad_node_def, &status));
+ TF_CHECK_OK(status);
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> avgpool_context(new OpKernelContext(params));
+
+ op->Compute(avgpool_context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete avgpool_context->release_output(0).tensor;
+ op->Compute(avgpool_context.get());
+ }
+ tensorflow::testing::StopTiming();
+ testing::ItemsProcessed(avgpool_context->mutable_output(0)->NumElements() *
+ iters);
+ testing::SetLabel(label);
+}
+
+// BS: batch_size
+// IR: input_rows
+// IC: input_cols
+// ND: node_depth
+// KR: kernel_rows
+// KC: kernel_cols
+// ST: stride. We use the same stride for both directions.
+// PT: padding
+// The resulting symbol is too long; we need two macros to fit in 80 chars.
+#define BM_AvgPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
+ static void \
+ BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
+ int iters) { \
+ BM_AvgPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_AvgPoolBk_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
+
+// Shapes taken from the 2015/05/16 inception model
+BM_AvgPoolBkCPU(32, 35, 35, 192, 3, 3, 1, SAME, 1, "avgpool_grad0_SAME");
+BM_AvgPoolBkCPU(32, 35, 35, 256, 3, 3, 1, SAME, 1, "avgpool_grad1_SAME");
+BM_AvgPoolBkCPU(32, 17, 17, 768, 3, 3, 1, SAME, 1, "avgpool_grad2_SAME");
+BM_AvgPoolBkCPU(32, 17, 17, 1024, 3, 3, 1, SAME, 1, "avgpool_grad3_SAME");
+BM_AvgPoolBkCPU(32, 17, 17, 1152, 3, 3, 1, SAME, 1, "avgpool_grad4_SAME");
+BM_AvgPoolBkCPU(32, 17, 17, 1216, 3, 3, 1, SAME, 1, "avgpool_grad5_SAME");
+BM_AvgPoolBkCPU(32, 17, 17, 1248, 5, 5, 3, VALID, 1, "avgpool_grad6_VALID");
+BM_AvgPoolBkCPU(32, 8, 8, 1760, 3, 3, 1, SAME, 1, "avgpool_grad7_SAME");
+BM_AvgPoolBkCPU(32, 8, 8, 2048, 8, 8, 1, VALID, 1, "avgpool_grad8_VALID");
+
+/*
+MaxPooling Op
+*/
+static void BM_MaxPool(int iters, int batch_size, int rows, int cols, int depth,
+ int kernel_rows, int kernel_cols, int stride,
+ Padding padding, int num_threads, const string& label) {
+ tensorflow::testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
+ device->set_eigen_cpu_device(&eigen_cpu_device);
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TensorShape shape1({batch_size, rows, cols, depth});
+ Tensor input1(DT_FLOAT, shape1);
+ test::FillIota<float>(&input1, 1.0);
+ inputs.push_back({nullptr, &input1});
+
+ // MaxPooling op.
+ NodeDef maxpool_node_def;
+ CHECK_EQ(kernel_rows, kernel_cols);
+ Status status = NodeDefBuilder("maxpool_op", "MaxPool")
+ .Input(FakeInput())
+ .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", padding == VALID ? "VALID" : "SAME")
+ .Finalize(&maxpool_node_def);
+ TF_CHECK_OK(status);
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), maxpool_node_def, &status));
+ TF_CHECK_OK(status);
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> maxpool_context(new OpKernelContext(params));
+
+ op->Compute(maxpool_context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete maxpool_context->release_output(0).tensor;
+ op->Compute(maxpool_context.get());
+ }
+ tensorflow::testing::StopTiming();
+ testing::ItemsProcessed(maxpool_context->mutable_output(0)->NumElements() *
+ iters);
+ testing::SetLabel(label);
+}
+
+// BS: batch_size
+// IR: input_rows
+// IC: input_cols
+// ND: node_depth
+// KR: kernel_rows
+// KC: kernel_cols
+// ST: stride. We use the same stride for both directions.
+// PT: padding
+#define BM_MaxPoolFwdCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
+ static void \
+ BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH( \
+ int iters) { \
+ BM_MaxPool(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_MaxPool_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_##PT##_##TH)
+
+// Labels are taken from the 2014-July-24 version of imagenet
+BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 1, "maxpool0_VALID");
+BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 1, "maxpool1_VALID");
+BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 1, "maxpool4_VALID");
+BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 1, "maxpool10_VALID");
+BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 1, "maxpool0_SAME");
+BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 1, "maxpool1_SAME");
+BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 1, "maxpool4_SAME");
+BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 1, "maxpool10_SAME");
+BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, VALID, 4, "maxpool0_VALID");
+BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, VALID, 4, "maxpool1_VALID");
+BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, VALID, 4, "maxpool4_VALID");
+BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, VALID, 4, "maxpool10_VALID");
+BM_MaxPoolFwdCPU(32, 112, 112, 64, 3, 3, 2, SAME, 4, "maxpool0_SAME");
+BM_MaxPoolFwdCPU(32, 56, 56, 192, 3, 3, 2, SAME, 4, "maxpool1_SAME");
+BM_MaxPoolFwdCPU(32, 28, 28, 352, 3, 3, 2, SAME, 4, "maxpool4_SAME");
+BM_MaxPoolFwdCPU(32, 14, 14, 576, 3, 3, 2, SAME, 4, "maxpool10_SAME");
+
+static void BM_MaxPoolBk(int iters, int batch_size, int rows, int cols,
+ int depth, int kernel_rows, int kernel_cols,
+ int stride, Padding padding, int num_threads,
+ bool use_gpu, const string& label) {
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+
+ int out_height, out_width, pad_rows, pad_cols;
+ Status status =
+ Get2dOutputSize(rows, cols, kernel_rows, kernel_cols, stride, stride,
+ padding, &out_height, &out_width, &pad_rows, &pad_cols);
+ TF_CHECK_OK(status);
+
+ Tensor input_data(DT_FLOAT, TensorShape({batch_size, rows, cols, depth}));
+ input_data.flat<float>().setRandom();
+ Node* input_data_node = ops::Const(input_data, b.opts());
+
+ Tensor output_data(DT_FLOAT,
+ TensorShape({batch_size, out_height, out_width, depth}));
+ output_data.flat<float>().setRandom();
+ Node* output_data_node = ops::Const(output_data, b.opts());
+
+ Tensor output_diff(DT_FLOAT,
+ TensorShape({batch_size, out_height, out_width, depth}));
+ output_diff.flat<float>().setRandom();
+ Node* output_diff_node = ops::Const(output_diff, b.opts());
+
+ CHECK_EQ(kernel_rows, kernel_cols);
+ ops::MaxPoolGrad(input_data_node, output_data_node, output_diff_node,
+ {1, kernel_rows, kernel_cols, 1} /* ksize */,
+ {1, stride, stride, 1} /* stride */,
+ padding == VALID ? "VALID" : "SAME", b.opts());
+ Graph* g = new Graph(OpRegistry::Global());
+ TF_CHECK_OK(b.ToGraph(g));
+ string device = use_gpu ? "gpu" : "cpu";
+ test::Benchmark(device, g).Run(iters);
+
+ testing::ItemsProcessed(batch_size * rows * cols * depth * iters);
+ testing::SetLabel(label);
+}
+
+// BS: batch_size
+// IR: input_rows
+// IC: input_cols
+// ND: node_depth
+// KR: kernel_rows
+// KC: kernel_cols
+// ST: stride. We use the same stride for both directions.
+// PT: padding
+// The resulting symbol is too long; we need two macros to fit in 80 chars.
+// clang-format off
+#define BM_MaxPoolBkGPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
+ static void \
+ BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
+ ##PT##_##TH( \
+ int iters) { \
+ BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, true, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_MaxPoolBk_GPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
+ ##PT##_##TH) \
+
+#define BM_MaxPoolBkCPU(BS, IR, IC, ND, KR, KC, ST, PT, TH, LABEL) \
+ static void \
+ BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
+ ##PT##_##TH( \
+ int iters) { \
+ BM_MaxPoolBk(iters, BS, IR, IC, ND, KR, KC, ST, PT, TH, false, LABEL); \
+ } \
+ BENCHMARK( \
+ BM_MaxPoolBk_CPU_##BS##_##IR##_##IC##_##ND##_##KR##_##KC##_##ST##_ \
+ ##PT##_##TH)
+// clang-format on
+
+// Shapes taken from the 2015/05/16 inception model
+BM_MaxPoolBkGPU(32, 147, 147, 64, 3, 3, 2, VALID, 1, "maxpool_grad0_VALID");
+BM_MaxPoolBkGPU(32, 71, 71, 192, 3, 3, 2, VALID, 1, "maxpool_grad1_VALID");
+BM_MaxPoolBkGPU(32, 35, 35, 288, 3, 3, 2, VALID, 1, "maxpool_grad2_VALID");
+BM_MaxPoolBkGPU(32, 17, 17, 1248, 3, 3, 2, VALID, 1, "maxpool_grad3_VALID");
+BM_MaxPoolBkGPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID");
+
+BM_MaxPoolBkCPU(32, 147, 147, 64, 3, 3, 2, VALID, 1, "maxpool_grad0_VALID");
+BM_MaxPoolBkCPU(32, 71, 71, 192, 3, 3, 2, VALID, 1, "maxpool_grad1_VALID");
+BM_MaxPoolBkCPU(32, 35, 35, 288, 3, 3, 2, VALID, 1, "maxpool_grad2_VALID");
+BM_MaxPoolBkCPU(32, 17, 17, 1248, 3, 3, 2, VALID, 1, "maxpool_grad3_VALID");
+BM_MaxPoolBkCPU(32, 8, 8, 2048, 3, 3, 2, VALID, 1, "maxpool_grad4_VALID");
+
+/*
+Relu Op
+Run benchmark with:
+*/
+static void BM_ReluFloat(int iters, int batch_size, int rows, int cols,
+ int depth, int num_threads, const string& label) {
+ tensorflow::testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
+ device->set_eigen_cpu_device(&eigen_cpu_device);
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TensorShape shape1({batch_size, rows, cols, depth});
+ Tensor input1(DT_FLOAT, shape1);
+ test::FillIota<float>(&input1, 1.0);
+ inputs.push_back({nullptr, &input1});
+
+  // Relu op.
+ NodeDef relu_node_def;
+ Status status = NodeDefBuilder("relu_op", "Relu")
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(&relu_node_def);
+ TF_CHECK_OK(status);
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), relu_node_def, &status));
+ TF_CHECK_OK(status);
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> relu_context(new OpKernelContext(params));
+
+ op->Compute(relu_context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete relu_context->release_output(0).tensor;
+ op->Compute(relu_context.get());
+ }
+ tensorflow::testing::StopTiming();
+ testing::ItemsProcessed(relu_context->mutable_output(0)->NumElements() *
+ iters);
+ testing::SetLabel(label);
+}
+
+// BS: batch_size
+// IR: input_rows
+// IC: input_cols
+// ND: node_depth
+#define BM_Relu(BS, IR, IC, ND, TH, LABEL) \
+ static void BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH(int iters) { \
+ BM_ReluFloat(iters, BS, IR, IC, ND, TH, LABEL); \
+ } \
+ BENCHMARK(BM_ReluFloat_##BS##_##IR##_##IC##_##ND##_##TH)
+
+BM_Relu(32, 112, 112, 64, 1, "relu0");
+BM_Relu(32, 56, 56, 192, 1, "relu1");
+BM_Relu(32, 28, 28, 352, 1, "relu4");
+BM_Relu(32, 14, 14, 576, 1, "relu10");
+BM_Relu(32, 112, 112, 64, 4, "relu0");
+BM_Relu(32, 56, 56, 192, 4, "relu1");
+BM_Relu(32, 28, 28, 352, 4, "relu4");
+BM_Relu(32, 14, 14, 576, 4, "relu10");
+
+static void BM_ImageNetSoftmaxFwd(int iters, int batch_size, int node_depth,
+ int num_threads, const string& label) {
+ tensorflow::testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ thread::ThreadPool threadpool(Env::Default(), "test", num_threads);
+ EigenThreadPoolWrapper wrapper(&threadpool);
+ Eigen::ThreadPoolDevice eigen_cpu_device(&wrapper, num_threads);
+ device->set_eigen_cpu_device(&eigen_cpu_device);
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TensorShape shape1({node_depth, batch_size});
+ Tensor* input1 = new Tensor(DT_FLOAT, shape1);
+ test::FillIota<float>(input1, 1.0);
+ inputs.push_back({nullptr, input1});
+
+ // Softmax op.
+ NodeDef softmax_node_def;
+ TF_CHECK_OK(NodeDefBuilder("softmax_op", "Softmax")
+ .Input("input", 0, DT_FLOAT)
+ .Finalize(&softmax_node_def));
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), softmax_node_def, &status));
+ TF_CHECK_OK(status);
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> softmax_context(new OpKernelContext(params));
+
+ op->Compute(softmax_context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete softmax_context->release_output(0).tensor;
+ op->Compute(softmax_context.get());
+ }
+ tensorflow::testing::StopTiming();
+ testing::ItemsProcessed(softmax_context->mutable_output(0)->NumElements() *
+ iters);
+ testing::SetLabel(label);
+}
+
+#define BM_ImageNetSoftmaxFwdCPU(BATCH_SIZE, NODE_DEPTH, TH, LABEL) \
+ static void BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH( \
+ int iters) { \
+ BM_ImageNetSoftmaxFwd(iters, BATCH_SIZE, NODE_DEPTH, TH, LABEL); \
+ } \
+ BENCHMARK(BM_ImageNetSoftmaxFwd_##BATCH_SIZE##_##NODE_DEPTH##_##TH)
+
+// Labels are taken from the 2014-July-24 version of imagenet
+BM_ImageNetSoftmaxFwdCPU(32, 1008, 1, "softmax32");
+BM_ImageNetSoftmaxFwdCPU(128, 1008, 1, "softmax128");
+BM_ImageNetSoftmaxFwdCPU(32, 1008, 4, "softmax32");
+BM_ImageNetSoftmaxFwdCPU(128, 1008, 4, "softmax128");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/count_up_to_op.cc b/tensorflow/core/kernels/count_up_to_op.cc
new file mode 100644
index 0000000000..7cf4bdb6d0
--- /dev/null
+++ b/tensorflow/core/kernels/count_up_to_op.cc
@@ -0,0 +1,51 @@
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+template <class T>
+class CountUpToOp : public OpKernel {
+ public:
+ explicit CountUpToOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("limit", &limit_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ T before_increment;
+ {
+ mutex_lock l(*context->input_ref_mutex(0));
+ Tensor tensor = context->mutable_input(0, true);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(tensor.shape()),
+ errors::InvalidArgument("input is not a scalar: ",
+ tensor.shape().DebugString()));
+ T* ptr = &tensor.scalar<T>()();
+ before_increment = *ptr;
+ if (*ptr >= limit_) {
+ context->SetStatus(errors::OutOfRange("Reached limit of ", limit_));
+ return;
+ }
+ ++*ptr;
+ }
+ // Output if no error.
+ Tensor* out_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output("output", TensorShape({}),
+ &out_tensor));
+ out_tensor->scalar<T>()() = before_increment;
+ }
+
+ private:
+ T limit_;
+};
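+
+// Illustrative behavior: with limit = 3 and a ref variable v initialized to
+// 0, successive CountUpTo(v) calls return 0, 1 and 2 (the value before the
+// increment) and leave v == 3; the next call sets an OutOfRange status and
+// leaves v unchanged.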
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("CountUpTo").TypeConstraint<TYPE>("T").Device(DEVICE_CPU), \
+ CountUpToOp<TYPE>)
+
+REGISTER(int32);
+REGISTER(int64);
+
+#undef REGISTER
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_abs.cc b/tensorflow/core/kernels/cwise_op_abs.cc
new file mode 100644
index 0000000000..5d39b88166
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_abs.cc
@@ -0,0 +1,23 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(UnaryOp, CPU, "Abs", functor::abs, float, double, int32, int64);
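+// The REGISTER4 macro above presumably expands to one REGISTER_KERNEL_BUILDER
+// call per listed type; e.g. for float, roughly:
+//   REGISTER_KERNEL_BUILDER(
+//       Name("Abs").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+//       UnaryOp<CPUDevice, functor::abs<float>>);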
+#ifndef __ANDROID__
+REGISTER_KERNEL_BUILDER(Name("ComplexAbs").Device(DEVICE_CPU),
+ UnaryOp<CPUDevice, functor::abs<complex64>>);
+#endif
+#if GOOGLE_CUDA
+REGISTER3(UnaryOp, GPU, "Abs", functor::abs, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Abs")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ UnaryOp<CPUDevice, functor::abs<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_add.cc b/tensorflow/core/kernels/cwise_op_add.cc
new file mode 100644
index 0000000000..a6cd4bddbe
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_add.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER7(BinaryOp, CPU, "Add", functor::add, float, double, int32, int64, int8,
+ int16, complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Add", functor::add, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Add")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::add<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_ceil.cc b/tensorflow/core/kernels/cwise_op_ceil.cc
new file mode 100644
index 0000000000..0a8f1313f8
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_ceil.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "Ceil", functor::ceil, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Ceil", functor::ceil, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_complex.cc b/tensorflow/core/kernels/cwise_op_complex.cc
new file mode 100644
index 0000000000..825181bc35
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_complex.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_CPU),
+ BinaryOp<CPUDevice, functor::make_complex<float>>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("Complex").Device(DEVICE_GPU),
+ BinaryOp<GPUDevice, functor::make_complex<float>>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_conj.cc b/tensorflow/core/kernels/cwise_op_conj.cc
new file mode 100644
index 0000000000..ba445d1c3d
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_conj.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_CPU),
+ UnaryOp<CPUDevice, functor::conj<complex64>>);
+#if GOOGLE_CUDA
+// REGISTER_KERNEL_BUILDER(Name("Conj").Device(DEVICE_GPU),
+// UnaryOp<GPUDevice, functor::conj<complex64>>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc
new file mode 100644
index 0000000000..45e24fc2ec
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_cos.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Cos", functor::cos, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Cos", functor::cos, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
new file mode 100644
index 0000000000..76d606ed03
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Div", functor::div, float, double, int32, int64,
+ complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Div", functor::div, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Div")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::div<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_equal_to.cc b/tensorflow/core/kernels/cwise_op_equal_to.cc
new file mode 100644
index 0000000000..8369299332
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_equal_to.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Equal", functor::equal_to, float, double, int32,
+ int64, complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Equal", functor::equal_to, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Equal")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::equal_to<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_exp.cc b/tensorflow/core/kernels/cwise_op_exp.cc
new file mode 100644
index 0000000000..b2603a1b4c
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_exp.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Exp", functor::exp, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Exp", functor::exp, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_floor.cc b/tensorflow/core/kernels/cwise_op_floor.cc
new file mode 100644
index 0000000000..83c8203953
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_floor.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "Floor", functor::floor, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Floor", functor::floor, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc
new file mode 100644
index 0000000000..59436afbc0
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_abs.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY3(abs, float, double, int64);
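+// DEFINE_UNARY3 presumably instantiates UnaryFunctor<GPUDevice, abs<T>> for
+// each listed type, analogous to the hand-written
+// "template struct UnaryFunctor<GPUDevice, logical_not>;" in
+// cwise_op_gpu_logical_not.cu.cc.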
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc
new file mode 100644
index 0000000000..edf8e0d1a5
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_add.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(add, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc
new file mode 100644
index 0000000000..f24c4b8b73
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_ceil.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(ceil, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc
new file mode 100644
index 0000000000..29086b5c71
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_complex.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY1(make_complex, float);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc
new file mode 100644
index 0000000000..cae22cea8e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_conj.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+// DEFINE_UNARY1(conj, complex64); // not working
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc
new file mode 100644
index 0000000000..c8412496a8
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_cos.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(cos, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
new file mode 100644
index 0000000000..c581c0487e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(div, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc
new file mode 100644
index 0000000000..f994822a74
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_equal_to.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY4(equal_to, float, double, int64, complex64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc
new file mode 100644
index 0000000000..caeaa19cef
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_exp.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(exp, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc
new file mode 100644
index 0000000000..0a06ff2978
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_floor.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(floor, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc
new file mode 100644
index 0000000000..e1278e077b
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_greater.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(greater, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc
new file mode 100644
index 0000000000..fafcf9b28a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_greater_equal.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(greater_equal, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc
new file mode 100644
index 0000000000..0370782c96
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_imag.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY1(get_imag, complex64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc
new file mode 100644
index 0000000000..020abef210
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_inverse.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY3(inverse, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc
new file mode 100644
index 0000000000..7a3a273af7
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_isfinite.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(isfinite, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc
new file mode 100644
index 0000000000..cfc4be3d25
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_isinf.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(isinf, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc
new file mode 100644
index 0000000000..c93b74387e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_isnan.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(isnan, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_less.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_less.cu.cc
new file mode 100644
index 0000000000..8e2b28ac60
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_less.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(less, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc
new file mode 100644
index 0000000000..be8e34a58b
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_less_equal.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(less_equal, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_log.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_log.cu.cc
new file mode 100644
index 0000000000..7d183cce50
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_log.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(log, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc
new file mode 100644
index 0000000000..ba7046f9f0
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_logical_and.cu.cc
@@ -0,0 +1,13 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+template struct BinaryFunctor<GPUDevice, logical_and, 1>;
+template struct BinaryFunctor<GPUDevice, logical_and, 2>;
+template struct BinaryFunctor<GPUDevice, logical_and, 3>;
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc
new file mode 100644
index 0000000000..34a43a76ef
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_logical_not.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+template struct UnaryFunctor<GPUDevice, logical_not>;
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc
new file mode 100644
index 0000000000..47a7bd68dc
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_logical_or.cu.cc
@@ -0,0 +1,13 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+template struct BinaryFunctor<GPUDevice, logical_or, 1>;
+template struct BinaryFunctor<GPUDevice, logical_or, 2>;
+template struct BinaryFunctor<GPUDevice, logical_or, 3>;
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc
new file mode 100644
index 0000000000..8f7ab90e9a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_maximum.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(maximum, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc
new file mode 100644
index 0000000000..75fd7f89b4
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_minimum.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(minimum, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc
new file mode 100644
index 0000000000..d08a17a94d
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_mod.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+// No GPU ops for mod yet.
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc
new file mode 100644
index 0000000000..e0a6738bef
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_mul.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(mul, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc
new file mode 100644
index 0000000000..3031afbb75
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_neg.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY4(neg, float, double, int32, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc
new file mode 100644
index 0000000000..59c76ee88b
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_not_equal_to.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY4(not_equal_to, float, double, int64, complex64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc
new file mode 100644
index 0000000000..50177495bc
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_pow.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(pow, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc
new file mode 100644
index 0000000000..3b1d465914
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_real.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY1(get_real, complex64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc
new file mode 100644
index 0000000000..682e2d2d4b
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_rsqrt.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(rsqrt, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
new file mode 100644
index 0000000000..b5125648e3
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc
@@ -0,0 +1,15 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+template struct SelectFunctor<GPUDevice, float>;
+template struct SelectFunctor<GPUDevice, double>;
+template struct SelectFunctor<GPUDevice, int32>;
+template struct SelectFunctor<GPUDevice, int64>;
+template struct SelectFunctor<GPUDevice, complex64>;
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
new file mode 100644
index 0000000000..9c250f3071
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(sigmoid, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc
new file mode 100644
index 0000000000..f413480ecc
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sign.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY3(sign, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc
new file mode 100644
index 0000000000..6135f3b780
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sin.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(sin, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc
new file mode 100644
index 0000000000..9bdf3b9e30
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sqrt.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(sqrt, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_square.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_square.cu.cc
new file mode 100644
index 0000000000..6b900e994d
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_square.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY3(square, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc
new file mode 100644
index 0000000000..6fd5ea0d38
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_sub.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY3(sub, float, double, int64);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
new file mode 100644
index 0000000000..e0393f6c2a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
@@ -0,0 +1,11 @@
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(tanh, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
new file mode 100644
index 0000000000..9ae31dcdfe
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(BinaryOp, CPU, "Greater", functor::greater, float, double, int32,
+ int64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Greater", functor::greater, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Greater")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::greater<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
new file mode 100644
index 0000000000..be4cc5dc79
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float, double,
+ int32, int64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float, double,
+ int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("GreaterEqual")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::greater_equal<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_imag.cc b/tensorflow/core/kernels/cwise_op_imag.cc
new file mode 100644
index 0000000000..c2432326fc
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_imag.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_CPU),
+ UnaryOp<CPUDevice, functor::get_imag<complex64>>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("Imag").Device(DEVICE_GPU),
+ UnaryOp<GPUDevice, functor::get_imag<complex64>>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_inverse.cc b/tensorflow/core/kernels/cwise_op_inverse.cc
new file mode 100644
index 0000000000..6af883e755
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_inverse.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Inv", functor::inverse, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER3(UnaryOp, GPU, "Inv", functor::inverse, float, double, int64);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_isfinite.cc b/tensorflow/core/kernels/cwise_op_isfinite.cc
new file mode 100644
index 0000000000..e52d199a8f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_isfinite.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "IsFinite", functor::isfinite, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "IsFinite", functor::isfinite, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc
new file mode 100644
index 0000000000..868204f86e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_isinf.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "IsInf", functor::isinf, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "IsInf", functor::isinf, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_isnan.cc b/tensorflow/core/kernels/cwise_op_isnan.cc
new file mode 100644
index 0000000000..a8f4d60d0f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_isnan.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "IsNan", functor::isnan, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "IsNan", functor::isnan, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
new file mode 100644
index 0000000000..3b5f75445c
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -0,0 +1,20 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(BinaryOp, CPU, "Less", functor::less, float, double, int32, int64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Less", functor::less, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Less")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::less<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
new file mode 100644
index 0000000000..507c7c2908
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(BinaryOp, CPU, "LessEqual", functor::less_equal, float, double, int32,
+ int64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "LessEqual", functor::less_equal, float, double,
+ int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("LessEqual")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::less_equal<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc
new file mode 100644
index 0000000000..ebc7cbcc4e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_log.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Log", functor::log, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Log", functor::log, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_logical_and.cc b/tensorflow/core/kernels/cwise_op_logical_and.cc
new file mode 100644
index 0000000000..a4075088f4
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_logical_and.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_CPU),
+ BinaryOp<CPUDevice, functor::logical_and>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("LogicalAnd").Device(DEVICE_GPU),
+ BinaryOp<GPUDevice, functor::logical_and>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_logical_not.cc b/tensorflow/core/kernels/cwise_op_logical_not.cc
new file mode 100644
index 0000000000..b2e97bf70c
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_logical_not.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("LogicalNot").Device(DEVICE_CPU),
+ UnaryOp<CPUDevice, functor::logical_not>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("LogicalNot").Device(DEVICE_GPU),
+ UnaryOp<GPUDevice, functor::logical_not>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_logical_or.cc b/tensorflow/core/kernels/cwise_op_logical_or.cc
new file mode 100644
index 0000000000..0d1df082f7
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_logical_or.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("LogicalOr").Device(DEVICE_CPU),
+ BinaryOp<CPUDevice, functor::logical_or>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("LogicalOr").Device(DEVICE_GPU),
+ BinaryOp<GPUDevice, functor::logical_or>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc
new file mode 100644
index 0000000000..c0c9e3f6f5
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_maximum.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(BinaryOp, CPU, "Maximum", functor::maximum, float, double, int32,
+ int64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Maximum", functor::maximum, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Maximum")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::maximum<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc
new file mode 100644
index 0000000000..4c6bf7df05
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_minimum.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(BinaryOp, CPU, "Minimum", functor::minimum, float, double, int32,
+ int64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Minimum", functor::minimum, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Minimum")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::minimum<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_mod.cc b/tensorflow/core/kernels/cwise_op_mod.cc
new file mode 100644
index 0000000000..17f2834030
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_mod.cc
@@ -0,0 +1,6 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(BinaryOp, CPU, "Mod", functor::mod, int32, int64);
+REGISTER2(BinaryOp, CPU, "Mod", functor::fmod, float, double);
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_mul.cc b/tensorflow/core/kernels/cwise_op_mul.cc
new file mode 100644
index 0000000000..15f65012cd
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_mul.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER7(BinaryOp, CPU, "Mul", functor::mul, float, double, int32, int64, int8,
+ int16, complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Mul", functor::mul, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Mul")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::mul<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_neg.cc b/tensorflow/core/kernels/cwise_op_neg.cc
new file mode 100644
index 0000000000..3a19b2e94f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_neg.cc
@@ -0,0 +1,9 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(UnaryOp, CPU, "Neg", functor::neg, float, double, int32, complex64,
+ int64);
+#if GOOGLE_CUDA
+REGISTER4(UnaryOp, GPU, "Neg", functor::neg, float, double, int32, int64);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to.cc b/tensorflow/core/kernels/cwise_op_not_equal_to.cc
new file mode 100644
index 0000000000..02d434a1c2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, double,
+ int32, int64, complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, double,
+ int64);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc
new file mode 100644
index 0000000000..d10dced85f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_pow.cc
@@ -0,0 +1,9 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Pow", functor::pow, float, double, int32, int64,
+ complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Pow", functor::pow, float, double, int64);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_real.cc b/tensorflow/core/kernels/cwise_op_real.cc
new file mode 100644
index 0000000000..84295a5a16
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_real.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_CPU),
+ UnaryOp<CPUDevice, functor::get_real<complex64>>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("Real").Device(DEVICE_GPU),
+ UnaryOp<GPUDevice, functor::get_real<complex64>>);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc
new file mode 100644
index 0000000000..a22b1209de
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
new file mode 100644
index 0000000000..baa821690a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -0,0 +1,17 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER_SELECT(CPU, "Select", "", float);
+REGISTER_SELECT(CPU, "Select", "", double);
+REGISTER_SELECT(CPU, "Select", "", int32);
+REGISTER_SELECT(CPU, "Select", "", int64);
+REGISTER_SELECT(CPU, "Select", "", complex64);
+REGISTER_SELECT(CPU, "Select", "", string);
+#if GOOGLE_CUDA
+REGISTER_SELECT(GPU, "Select", "", float);
+REGISTER_SELECT(GPU, "Select", "", double);
+REGISTER_SELECT(GPU, "Select", "", int32);
+REGISTER_SELECT(GPU, "Select", "", int64);
+REGISTER_SELECT(GPU, "Select", "", complex64);
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
new file mode 100644
index 0000000000..e03b5d54dd
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc
new file mode 100644
index 0000000000..59a0bfa1ed
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sign.cc
@@ -0,0 +1,19 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER4(UnaryOp, CPU, "Sign", functor::sign, float, double, int32, int64);
+#if GOOGLE_CUDA
+REGISTER3(UnaryOp, GPU, "Sign", functor::sign, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Sign")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ UnaryOp<CPUDevice, functor::sign<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc
new file mode 100644
index 0000000000..e7c87374d7
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sin.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Sin", functor::sin, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Sin", functor::sin, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc
new file mode 100644
index 0000000000..f43241264a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sqrt.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Sqrt", functor::sqrt, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc
new file mode 100644
index 0000000000..510fda49aa
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_square.cc
@@ -0,0 +1,9 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(UnaryOp, CPU, "Square", functor::square, float, double, int32,
+ complex64, int64);
+#if GOOGLE_CUDA
+REGISTER3(UnaryOp, GPU, "Square", functor::square, float, double, int64);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc
new file mode 100644
index 0000000000..c3c5952f8d
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_sub.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Sub", functor::sub, float, double, int32, int64,
+ complex64);
+#if GOOGLE_CUDA
+REGISTER3(BinaryOp, GPU, "Sub", functor::sub, float, double, int64);
+#endif
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Sub")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .HostMemory("z")
+ .TypeConstraint<int32>("T"),
+ BinaryOp<CPUDevice, functor::sub<int32>>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc
new file mode 100644
index 0000000000..31f4743449
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_tanh.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "Tanh", functor::tanh, float, double, complex64);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Tanh", functor::tanh, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
new file mode 100644
index 0000000000..7d818cfbbf
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -0,0 +1,607 @@
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_H_
+
+#include <cmath>
+#include <functional>
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+// The following functors (sign, tanh, sigmoid, etc.) are not defined
+// by Eigen. When their equivalents are added to Eigen, we can replace them
+// with type aliases.
+
+namespace Eigen {
+namespace internal {
+
+template <typename T>
+struct scalar_sign_op {
+ // TODO(zhifengc): this only works for real types. In theory,
+ // sign(x) = x / |x| works for both real and complex values.
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op);
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
+ return T(x > T(0)) - T(x < T(0));
+ }
+};
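+// e.g. the functor above yields sign(-2.5f) == -1.0f, sign(0.0f) == 0.0f and
+// sign(4) == 1.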
+
+// TODO(zhifengc): Eigen::internal::pow_impl does not have proper
+// EIGEN host/device decoration. We duplicate code here for now.
+template <typename T, bool IsInteger>
+struct pow {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
+ operator()(const T& x, const T& y) const {
+ return std::pow(x, y);
+ }
+};
+
+template <typename T>
+struct pow<T, true> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x, T y) const {
+ T res(1);
+ if (y & 1) res *= x;
+ y >>= 1;
+ while (y) {
+ x *= x;
+ if (y & 1) res *= x;
+ y >>= 1;
+ }
+ return res;
+ }
+};
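+// Illustrative trace of the integer specialization above for x = 3, y = 5
+// (binary 101): bit 0 of y is set, so res starts at 3; x then squares to 9
+// (bit 1 clear, res unchanged) and to 81 (bit 2 set), giving
+// res = 3 * 81 = 243 = 3^5. This is exponentiation by squaring, using
+// O(log y) multiplications.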
+
+template <typename T>
+struct scalar_pow2_op : pow<T, NumTraits<T>::IsInteger> {};
+
+template <typename T>
+struct functor_traits<scalar_pow2_op<T> > {
+ enum {
+ Cost = 5 * NumTraits<T>::MulCost,
+ PacketAccess = false,
+ };
+};
+
+template <typename T>
+struct scalar_fmod2_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
+ const T& b) const {
+ return fmod(a, b);
+ }
+};
+
+template <typename T>
+struct scalar_mod2_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_mod2_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& a, const T& b) const {
+ return a % b;
+ }
+};
+
+template <typename T>
+struct functor_traits<scalar_mod2_op<T> > {
+ enum {
+ Cost = 5, // Roughly the cost of a div
+ PacketAccess = false,
+ };
+};
+
+// scalar_left and scalar_right are template helpers to partially
+// apply a binary function.
+//
+// If Binary is a binary functor f(x, y), then scalar_left<> is a unary
+// functor g_x(y) = f(x, y), where x is provided via the constructor.
+// Similarly, scalar_right<> is a unary functor g_y(x) = f(x, y), where y is
+// provided via the constructor.
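+//
+// For example (illustrative):
+//   float c = 2.0f;
+//   scalar_left<float, float, Eigen::internal::scalar_sum_op<float>> add2(&c);
+//   add2(3.0f);  // == 2.0f + 3.0f == 5.0f
+// This lets a binary cwise op with one scalar operand be evaluated as a unary
+// pass over the other operand.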
+
+template <typename Tout, typename Tin, typename Binary,
+ bool PacketAccess = functor_traits<Binary>::PacketAccess>
+struct scalar_left {
+ typedef Tout result_type;
+ const Tin* left;
+ EIGEN_DEVICE_FUNC inline scalar_left(
+ const scalar_left& other) // NOLINT(runtime/explicit)
+ : left(other.left) {}
+ EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c) : left(c) {}
+ EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
+ return Binary()(*left, right);
+ }
+};
+
+template <typename Tout, typename Tin, typename Binary>
+struct scalar_left<Tout, Tin, Binary, true> {
+ typedef Tout result_type;
+ const Tin* left;
+ EIGEN_DEVICE_FUNC inline scalar_left(
+ const scalar_left& other) // NOLINT(runtime/explicit)
+ : left(other.left) {}
+ EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c) : left(c) {}
+ EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
+ return Binary()(*left, right);
+ }
+
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
+ const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
+ return Binary().packetOp(left_packet, right_packet);
+ }
+};
+
+template <typename Tout, typename Tin, typename Binary>
+struct functor_traits<scalar_left<Tout, Tin, Binary> > {
+ enum {
+ Cost = functor_traits<Binary>::Cost,
+ PacketAccess = functor_traits<Binary>::PacketAccess,
+ };
+};
+
+template <typename Tout, typename Tin, typename Binary,
+ bool PacketAccess = functor_traits<Binary>::PacketAccess>
+struct scalar_right {
+ typedef Tout result_type;
+ const Tin* right;
+ EIGEN_DEVICE_FUNC inline scalar_right(
+ const scalar_right& other) // NOLINT(runtime/explicit)
+ : right(other.right) {}
+ EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c) : right(c) {}
+ EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
+ return Binary()(left, *right);
+ }
+};
+
+template <typename Tout, typename Tin, typename Binary>
+struct scalar_right<Tout, Tin, Binary, true> {
+ typedef Tout result_type;
+ const Tin* right;
+ EIGEN_DEVICE_FUNC inline scalar_right(
+ const scalar_right& other) // NOLINT(runtime/explicit)
+ : right(other.right) {}
+ EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c) : right(c) {}
+ EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
+ return Binary()(left, *right);
+ }
+
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
+ const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
+ return Binary().packetOp(left_packet, right_packet);
+ }
+};
+
+template <typename Tout, typename Tin, typename Binary>
+struct functor_traits<scalar_right<Tout, Tin, Binary> > {
+ enum {
+ Cost = functor_traits<Binary>::Cost,
+ PacketAccess = functor_traits<Binary>::PacketAccess,
+ };
+};
+
+// similar to std::equal_to, but with the DEVICE_FUNC qualifier
+template <class T>
+struct equal_to : std::binary_function<T, T, bool> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool operator()(const T& x, const T& y) const { return x == y; }
+};
+
+// similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
+template <class T>
+struct not_equal_to : std::binary_function<T, T, bool> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool operator()(const T& x, const T& y) const { return x != y; }
+};
+
+// similar to std::greater, but with the DEVICE_FUNC qualifier
+template <class T>
+struct greater : std::binary_function<T, T, bool> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool operator()(const T& x, const T& y) const { return x > y; }
+};
+
+// similar to std::less, but with the DEVICE_FUNC qualifier
+template <class T>
+struct less : std::binary_function<T, T, bool> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool operator()(const T& x, const T& y) const { return x < y; }
+};
+
+// similar to std::greater_equal, but with the DEVICE_FUNC qualifier
+template <class T>
+struct greater_equal : std::binary_function<T, T, bool> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool operator()(const T& x, const T& y) const { return x >= y; }
+};
+
+// similar to std::less_equal, but with the DEVICE_FUNC qualifier
+template <class T>
+struct less_equal : std::binary_function<T, T, bool> {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ bool operator()(const T& x, const T& y) const { return x <= y; }
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+namespace tensorflow {
+namespace functor {
+
+////////////////////////////////////////////////////////////////////////////////
+// Helpers
+////////////////////////////////////////////////////////////////////////////////
+
+// Base template for functors whose input scalar type is T and
+// output scalar type is R.
+template <typename T, typename F, typename R = T>
+struct base {
+ // func defines operator() and its vectorized version packetOp().
+ typedef F func;
+
+ // If true, the functor's corresponding binary op will instantiate
+ // specialized kernels to perform an optimized broadcast
+ // operation. Each functor for which this is enabled increases the
+ // code size, so by default this is disabled for binary functors and
+ // is enabled on a per-op basis as needed.
+ static const bool use_bcast_optimization = false;
+
+ // operator() has the signature:
+ // out_type operator()(in_type in0, in_type in1 ...)
+ typedef R out_type;
+ typedef T in_type;
+
+ // TensorFlow provides a tensor-ized version of "func". Roughly
+ // speaking, the TensorFlow operation has the signature:
+ // tout_type op(tin_type in0)
+ // tout_type op(tin_type in0, tin_type in1)
+ // tout_type op(tin_type in0, in_type scalar)
+ typedef typename TTypes<out_type>::Flat tout_type;
+ typedef typename TTypes<in_type>::ConstFlat tin_type;
+ typedef typename TTypes<in_type>::ConstScalar tscalar_type;
+};
+
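+// As an illustration of the typedefs above (a sketch using functors defined
+// later in this file): for functor::add<float>,
+//   func       is Eigen::internal::scalar_sum_op<float>,
+//   in_type    is float and out_type is float,
+//   tin_type   is TTypes<float>::ConstFlat, and
+//   tout_type  is TTypes<float>::Flat.
+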
+// For now, we only apply certain speed optimizations to the
+// float/double broadcast binary ops.
+template <typename T>
+struct use_bcast_optimization {
+ static const bool value = false;
+};
+
+template <>
+struct use_bcast_optimization<float> {
+ static const bool value = true;
+};
+
+template <>
+struct use_bcast_optimization<double> {
+ static const bool value = true;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Unary functors
+////////////////////////////////////////////////////////////////////////////////
+
+// abs(x) = |x|
+// neg(x) = - x
+// inverse(x) = 1 / x
+// square(x) = x^2
+// sqrt(x) = x^(1/2)
+// rsqrt(x) = x^(-1/2)
+// exp(x) = e^x
+// log(x) = natural logarithm of x
+// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
+// sigmoid(x) = 1 / (1 + exp(-x)) // a.k.a. the logistic function
+//
+// NOTE: We may eventually implement common functions used in NN
+// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
+// For reference, see speech/lstm/eigen_functors.h.
+
+template <typename T>
+struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
+ typename Eigen::internal::scalar_abs_op<T>::result_type> {};
+
+template <typename T>
+struct neg : base<T, Eigen::internal::scalar_opposite_op<T> > {};
+
+template <typename T>
+struct inverse : base<T, Eigen::internal::scalar_inverse_op<T> > {};
+
+template <typename T>
+struct square : base<T, Eigen::internal::scalar_square_op<T> > {};
+
+template <typename T>
+struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T> > {};
+
+template <typename T>
+struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T> > {};
+
+template <typename T>
+struct exp : base<T, Eigen::internal::scalar_exp_op<T> > {};
+
+template <typename T>
+struct log : base<T, Eigen::internal::scalar_log_op<T> > {};
+
+template <typename T>
+struct sign : base<T, Eigen::internal::scalar_sign_op<T> > {};
+
+template <typename T>
+struct tanh : base<T, Eigen::internal::scalar_tanh_op<T> > {};
+
+template <typename T>
+struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T> > {};
+
+template <typename T>
+struct sin : base<T, Eigen::internal::scalar_sin_op<T> > {};
+
+template <typename T>
+struct cos : base<T, Eigen::internal::scalar_cos_op<T> > {};
+
+struct logical_not : base<bool, std::logical_not<bool> > {};
+
+namespace impl {
+
+#ifndef __CUDACC__
+// Uses the standard <cmath> functions.
+template <typename T>
+bool isinf(T v) {
+ return std::isinf(v);
+}
+
+template <typename T>
+bool isnan(T v) {
+ return std::isnan(v);
+}
+
+template <typename T>
+bool isfinite(T v) {
+ return std::isfinite(v);
+}
+
+template <typename T>
+T floor(T v) {
+ return std::floor(v);
+}
+
+template <typename T>
+T ceil(T v) {
+ return std::ceil(v);
+}
+#else
+// Uses CUDA's functions for float and double.
+template <typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isinf(T v) {
+ return ::isinf(v);
+}
+
+template <typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isnan(T v) {
+ return ::isnan(v);
+}
+
+template <typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isfinite(T v) {
+ return ::isfinite(v);
+}
+
+template <typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T floor(T v) {
+ return ::floor(v);
+}
+
+template <typename T>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T ceil(T v) {
+ return ::ceil(v);
+}
+#endif
+} // end namespace impl
+
+// NOTE: std::isinf, std::isnan, and std::isfinite are plain functions, so
+// we wrap them in functors in order to use them with Eigen's type system.
+
+template <typename T>
+struct isinf_func {
+ typedef bool result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const {
+ return impl::isinf(x);
+ }
+};
+
+template <typename T>
+struct isinf : base<T, isinf_func<T>, bool> {};
+
+template <typename T>
+struct isnan_func {
+ typedef bool result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const {
+ return impl::isnan(x);
+ }
+};
+
+template <typename T>
+struct isnan : base<T, isnan_func<T>, bool> {};
+
+template <typename T>
+struct isfinite_func {
+ typedef bool result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(T x) const {
+ return impl::isfinite(x);
+ }
+};
+
+template <typename T>
+struct isfinite : base<T, isfinite_func<T>, bool> {};
+
+template <typename T>
+struct floor_func {
+ typedef T result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x) const {
+ return impl::floor(x);
+ }
+};
+
+template <typename T>
+struct floor : base<T, floor_func<T> > {};
+
+template <typename T>
+struct ceil_func {
+ typedef T result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(T x) const {
+ return impl::ceil(x);
+ }
+};
+
+template <typename T>
+struct ceil : base<T, ceil_func<T> > {};
+
+////////////////////////////////////////////////////////////////////////////////
+// Binary functors
+////////////////////////////////////////////////////////////////////////////////
+
+// Binary functors:
+//
+// add(x, y) = x + y
+// sub(x, y) = x - y
+// mul(x, y) = x * y
+// div(x, y) = x / y
+// mod(x, y) = x % y (int32 and int64 only)
+// fmod(x, y) = fmod(x, y) (float and double only)
+// pow(x, y) = x ^ y
+// maximum(x, y) = x > y ? x : y
+// minimum(x, y) = x < y ? x : y
+
+template <typename T>
+struct add : base<T, Eigen::internal::scalar_sum_op<T> > {
+ static const bool use_bcast_optimization = true;
+};
+
+template <typename T>
+struct sub : base<T, Eigen::internal::scalar_difference_op<T> > {
+ static const bool use_bcast_optimization = true;
+};
+
+template <typename T>
+struct mul : base<T, Eigen::internal::scalar_product_op<T> > {};
+
+template <typename T>
+struct div : base<T, Eigen::internal::scalar_quotient_op<T> > {};
+
+template <typename T>
+struct fmod : base<T, Eigen::internal::scalar_fmod2_op<T> > {};
+
+template <typename T>
+struct mod : base<T, Eigen::internal::scalar_mod2_op<T> > {};
+
+template <typename T>
+struct pow : base<T, Eigen::internal::scalar_pow2_op<T> > {};
+
+template <typename T>
+struct maximum : base<T, Eigen::internal::scalar_max_op<T> > {};
+
+template <typename T>
+struct minimum : base<T, Eigen::internal::scalar_min_op<T> > {};
+
+template <typename T>
+struct less : base<T, Eigen::internal::less<T>, bool> {};
+
+template <typename T>
+struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};
+
+template <typename T>
+struct greater : base<T, Eigen::internal::greater<T>, bool> {};
+
+template <typename T>
+struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};
+
+template <typename T>
+struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};
+
+template <typename T>
+struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};
+
+struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};
+
+struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
+
+template <typename T>
+struct make_complex_func {
+ typedef std::complex<T> result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ result_type operator()(T real, T imag) const {
+ return std::complex<T>(real, imag);
+ }
+};
+
+template <typename T>
+struct make_complex : base<T, make_complex_func<T>, std::complex<T> > {};
+
+template <typename T>
+struct get_real
+ : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};
+
+template <typename T>
+struct get_imag
+ : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};
+
+template <typename T>
+struct conj : base<T, Eigen::internal::scalar_conjugate_op<T> > {};
+
+////////////////////////////////////////////////////////////////////////////////
+// Functors that take 1 or 2 tensors, compute the base functor on each
+// coefficient of the input tensors, and put the results in the output
+// tensor.
+////////////////////////////////////////////////////////////////////////////////
+template <typename Device, typename Functor>
+struct UnaryFunctor {
+ // Computes on device "d": out[i] = Functor(in[i])
+ void operator()(const Device& d, typename Functor::tout_type out,
+ typename Functor::tin_type in);
+};
+
+template <typename Device, typename Functor, int NDIMS>
+struct BinaryFunctor {
+ // Computes on device "d": out[i] = Functor(in0[i], in1[i])
+ void operator()(const Device& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1);
+
+ // Computes on device "d": out[i] = Functor(scalar[0], in[i])
+ void Left(const Device& d, typename Functor::tout_type out,
+ typename Functor::tscalar_type scalar,
+ typename Functor::tin_type in);
+
+ // Computes on device "d": out[i] = Functor(in[i], scalar[0])
+ void Right(const Device& d, typename Functor::tout_type out,
+ typename Functor::tin_type in,
+ typename Functor::tscalar_type scalar);
+
+ // Computes on device "d":
+ // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
+ //
+ // TODO(zhifengc): make BCast a template member function on NDIMS
+ // instead of making BinaryFunctor a template on NDIMS.
+ void BCast(const Device& d,
+ typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1);
+};
+
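+// How the BCast arguments are typically derived (a sketch; see
+// tensorflow/core/util/bcast.h): adding a [2, 3] tensor and a [3] vector is
+// handled as the NDIMS == 2 case with
+//   in0 reshaped to [2, 3] and bcast0 = {1, 1} (no broadcasting needed),
+//   in1 reshaped to [1, 3] and bcast1 = {2, 1} (replicated along dim 0).
+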
+template <int NDIMS>
+bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
+ for (int i = 0; i < a.size(); ++i) {
+ if (a[i] != 1) return false;
+ }
+ return true;
+}
+
+template <typename Device, typename T>
+struct SelectFunctor {
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<bool>::ConstFlat cond_flat,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat);
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_H_
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
new file mode 100644
index 0000000000..f86d2ddd9a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -0,0 +1,42 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+
+BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
+ DataType in)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
+}
+
+void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
+ ctx->SetStatus(errors::Unimplemented(
+ "Broadcast between ", ctx->input(0).shape().ShortDebugString(), " and ",
+ ctx->input(1).shape().ShortDebugString(), " is not supported yet."));
+}
+
+static BCast::Vec FromShape(const TensorShape& shape) {
+ BCast::Vec ret;
+ for (int i = 0; i < shape.dims(); ++i) ret.push_back(shape.dim_size(i));
+ return ret;
+}
+
+static TensorShape ToShape(const BCast::Vec& vec) {
+ TensorShape shape;
+ for (auto elem : vec) shape.AddDim(elem);
+ return shape;
+}
+
+BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
+ : bcast(FromShape(ctx->input(0).shape()),
+ FromShape(ctx->input(1).shape())) {
+ if (!bcast.IsValid()) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Incompatible shapes: ", ctx->input(0).shape().ShortDebugString(),
+ " vs. ", ctx->input(1).shape().ShortDebugString()));
+ return;
+ }
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, ToShape(bcast.output_shape()), &out));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
new file mode 100644
index 0000000000..cf848b86d1
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -0,0 +1,390 @@
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
+
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/cwise_ops.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+class BinaryOpShared : public OpKernel {
+ public:
+ explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);
+
+ protected:
+ struct BinaryOpState {
+ // Sets up bcast with the shape of in0 and in1, ensures that the bcast
+ // is valid, and if so, allocates out using ctx->allocate_output(...).
+ // Caller must check ctx->status() upon return for non-ok status.
+ // If ctx->status().ok() is true, then out is guaranteed to be allocated.
+ BinaryOpState(OpKernelContext* ctx);
+
+ BCast bcast;
+ Tensor* out = nullptr;
+ };
+
+ template <int NDIMS>
+ static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
+ const BCast::Vec& vec) {
+ CHECK_EQ(vec.size(), NDIMS);
+ Eigen::array<Eigen::DenseIndex, NDIMS> ret;
+ for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
+ return ret;
+ }
+ void SetUnimplementedError(OpKernelContext* ctx);
+};
+
+// Coefficient-wise binary operations:
+// Device: E.g., CPUDevice, GPUDevice.
+// Functor: defined in cwise_ops.h. E.g., functor::add.
+template <typename Device, typename Functor>
+class BinaryOp : public BinaryOpShared {
+ public:
+ typedef typename Functor::in_type Tin; // Input scalar data type.
+ typedef typename Functor::out_type Tout; // Output scalar data type.
+
+ explicit BinaryOp(OpKernelConstruction* ctx)
+ : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
+ DataTypeToEnum<Tin>::v()) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in0 = ctx->input(0);
+ const Tensor& in1 = ctx->input(1);
+ // 'state': Shared helper not dependent on T to reduce code size
+ BinaryOpState state(ctx);
+ if (!ctx->status().ok()) return;
+ Tensor* out = state.out;
+ BCast* bcast = &state.bcast;
+ if (out->NumElements() == 0) {
+ return;
+ }
+ const int ndims = bcast->x_reshape().size();
+ if (ndims <= 1) {
+ if (in1.NumElements() == 1) {
+ // tensor op scalar
+ functor::BinaryFunctor<Device, Functor, 1>().Right(
+ ctx->eigen_device<Device>(), out->flat<Tout>(), in0.flat<Tin>(),
+ in1.scalar<Tin>());
+ return;
+ }
+ if (in0.NumElements() == 1) {
+ // scalar op tensor
+ functor::BinaryFunctor<Device, Functor, 1>().Left(
+ ctx->eigen_device<Device>(), out->flat<Tout>(), in0.scalar<Tin>(),
+ in1.flat<Tin>());
+ return;
+ }
+ functor::BinaryFunctor<Device, Functor, 1>()(
+ ctx->eigen_device<Device>(), out->flat<Tout>(), in0.flat<Tin>(),
+ in1.flat<Tin>());
+ return;
+ }
+
+ if (ndims == 2) {
+ functor::BinaryFunctor<Device, Functor, 2>().BCast(
+ ctx->eigen_device<Device>(),
+ out->shaped<Tout, 2>(bcast->result_shape()),
+ in0.shaped<Tin, 2>(bcast->x_reshape()),
+ ToIndexArray<2>(bcast->x_bcast()),
+ in1.shaped<Tin, 2>(bcast->y_reshape()),
+ ToIndexArray<2>(bcast->y_bcast()));
+ return;
+ }
+
+ if (ndims == 3) {
+ functor::BinaryFunctor<Device, Functor, 3>().BCast(
+ ctx->eigen_device<Device>(),
+ out->shaped<Tout, 3>(bcast->result_shape()),
+ in0.shaped<Tin, 3>(bcast->x_reshape()),
+ ToIndexArray<3>(bcast->x_bcast()),
+ in1.shaped<Tin, 3>(bcast->y_reshape()),
+ ToIndexArray<3>(bcast->y_bcast()));
+ return;
+ }
+
+ SetUnimplementedError(ctx);
+ }
+
+ private:
+};
+
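+// For example (an illustrative sketch), BinaryOp<CPUDevice, functor::add<float> >
+// is the CPU float kernel behind the "Add" op; the REGISTER* macros at the
+// end of this file are used to register such instantiations.
+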
+// Coefficient-wise unary operations:
+// Device: E.g., CPUDevice, GPUDevice.
+// Functor: defined in cwise_ops.h. E.g., functor::sqrt.
+template <typename Device, typename Functor>
+class UnaryOp : public OpKernel {
+ public:
+ typedef typename Functor::in_type Tin; // Input scalar data type.
+ typedef typename Functor::out_type Tout; // Output scalar data type.
+ // Tin may be different from Tout. E.g., abs: complex64 -> float
+
+ explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ auto in = DataTypeToEnum<Tin>::v();
+ auto out = DataTypeToEnum<Tout>::v();
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& inp = ctx->input(0);
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
+ functor::UnaryFunctor<Device, Functor>()(
+ ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
+ }
+};
+
+// Coefficient-wise select operation.
+// Device: E.g., CPUDevice, GPUDevice.
+template <typename Device, typename T>
+class SelectOp : public OpKernel {
+ public:
+ explicit SelectOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ auto dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_BOOL, dt, dt}, {dt}));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in0 = ctx->input(0);
+ const Tensor& in1 = ctx->input(1);
+ const Tensor& in2 = ctx->input(2);
+ if (!ctx->ValidateInputsAreSameShape(this)) return;
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
+ functor::SelectFunctor<Device, T> func;
+ func(ctx->eigen_device<Device>(), out->flat<T>(), in0.flat<bool>(),
+ in1.flat<T>(), in2.flat<T>());
+ }
+};
+
+namespace functor {
+
+// For CPUDevice, we do operations inline if the resulting tensor is
+// modestly sized.
+static bool DoInline(size_t size) { return size <= 32768; }
+
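+// Evaluates "rhs" into "out", either inline on the calling thread for small
+// outputs (presumably to avoid the overhead of dispatching to the Eigen
+// thread pool) or on device "d" otherwise.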
+template <typename D, typename OUT, typename RHS>
+void Assign(const D& d, OUT out, RHS rhs) {
+ if (DoInline(out.size())) {
+ out = rhs;
+ } else {
+ out.device(d) = rhs;
+ }
+}
+
+// Partial specialization of BinaryFunctor<Device=CPUDevice, Functor>.
+template <typename Functor, int NDIMS>
+struct BinaryFunctor<CPUDevice, Functor, NDIMS> {
+ void operator()(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1) {
+ Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
+ }
+
+ void Left(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tscalar_type scalar,
+ typename Functor::tin_type in) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
+ Assign(d, out, in.unaryExpr(Unary(scalar.data())));
+ }
+
+ void Right(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in,
+ typename Functor::tscalar_type scalar) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
+ Assign(d, out, in.unaryExpr(Unary(scalar.data())));
+ }
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ inline Eigen::DSizes<int, 2> NByOne(int n) {
+ return Eigen::DSizes<int, 2>(n, 1);
+ }
+ inline Eigen::DSizes<int, 2> OneByM(int m) {
+ return Eigen::DSizes<int, 2>(1, m);
+ }
+#else
+ inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
+ Eigen::IndexList<int, Eigen::type2index<1>> ret;
+ ret.set(0, n);
+ return ret;
+ }
+ inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
+ Eigen::IndexList<Eigen::type2index<1>, int> ret;
+ ret.set(1, m);
+ return ret;
+ }
+#endif
+
+ void BCast(const CPUDevice& dev,
+ typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1) {
+ typedef typename Functor::in_type T;
+ typename Functor::func func;
+ if ((NDIMS == 2) && Functor::use_bcast_optimization &&
+ use_bcast_optimization<T>::value) {
+ // Optimize for speed by using Eigen::type2index and avoiding
+ // .broadcast() when we know it is a no-op.
+ //
+ // Here, we handle 6 cases depending on which of the 4 dimensions of
+ // in0 and in1 are equal to 1. It is not possible for the two shapes
+ // to contain more than two 1s, because those cases are simplified to
+ // the NDIMS == 1 case.
+ //
+ // Because this optimization increases the binary size for each
+ // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination, we
+ // only apply it for selected ops/types/ndims.
+ //
+ // Because NDIMS, Functor::use_bcast_optimization, and
+ // use_bcast_optimization<T> are compile-time constants, gcc does a
+ // decent job of not generating code when the conditions are not met.
+ const int a = in0.dimension(0); // in0 is shape [a, b]
+ const int b = in0.dimension(1);
+ const int c = in1.dimension(0); // in1 is shape [c, d]
+ const int d = in1.dimension(1);
+ if ((a == 1) && (d == 1)) {
+ auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
+ auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+ if ((b == 1) && (c == 1)) {
+ auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
+ auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+ if (a == 1) {
+ auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
+ auto rhs = in1;
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+ if (b == 1) {
+ auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
+ auto rhs = in1;
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+ if (c == 1) {
+ auto lhs = in0;
+ auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+ if (d == 1) {
+ auto lhs = in0;
+ auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+
+ const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
+ const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
+ if (bcast0_all_one && !bcast1_all_one) {
+ auto lhs = in0; // No need to do broadcast for in0
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+
+ if (!bcast0_all_one && bcast1_all_one) {
+ auto lhs = in0.broadcast(bcast0);
+ auto rhs = in1; // No need to do broadcast for in1
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ return;
+ }
+ }
+
+ // Fallback path. Always works, but is probably slower.
+ auto lhs = in0.broadcast(bcast0);
+ auto rhs = in1.broadcast(bcast1);
+ Assign(dev, out, lhs.binaryExpr(rhs, func));
+ }
+};
+
+// Partial specialization of UnaryFunctor<Device=CPUDevice, Functor>.
+template <typename Functor>
+struct UnaryFunctor<CPUDevice, Functor> {
+ void operator()(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in) {
+ Assign(d, out, in.unaryExpr(typename Functor::func()));
+ }
+};
+
+template <typename T>
+struct SelectFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<bool>::ConstFlat cond_flat,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat) {
+ Assign(d, out, cond_flat.select(then_flat, else_flat));
+ }
+};
+
+} // end namespace functor
+
+#define REGISTER_SELECT(D, N, F, T) \
+ REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ SelectOp<D##Device, T>)
+
+#define REGISTER(OP, D, N, F, T) \
+ REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ OP<D##Device, F<T>>);
+
+// Macros to register kernels for multiple types (T0, T1, etc.) on
+// device type "D" (CPU or GPU) for operatin "N" (e.g., sqrt) using
+// the functor "F" (e.g., functor:sqrt).
+
+#ifdef __ANDROID__
+// On Android, only register the first type (float)
+#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
+#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
+#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
+#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
+#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
+#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
+ REGISTER(OP, D, N, F, T0)
+#else // !__ANDROID__
+#define REGISTER2(OP, D, N, F, T0, T1) \
+ REGISTER(OP, D, N, F, T0) \
+ REGISTER(OP, D, N, F, T1)
+#define REGISTER3(OP, D, N, F, T0, T1, T2) \
+ REGISTER2(OP, D, N, F, T0, T1) \
+ REGISTER(OP, D, N, F, T2)
+#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
+ REGISTER2(OP, D, N, F, T0, T1) \
+ REGISTER2(OP, D, N, F, T2, T3)
+#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
+ REGISTER3(OP, D, N, F, T0, T1, T2) \
+ REGISTER2(OP, D, N, F, T3, T4)
+#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
+ REGISTER3(OP, D, N, F, T0, T1, T2) \
+ REGISTER3(OP, D, N, F, T3, T4, T5)
+#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
+ REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
+ REGISTER3(OP, D, N, F, T4, T5, T6)
+#endif // __ANDROID__
+
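+// Usage sketch (the op files named here are illustrative): a binary kernel
+// definition file such as cwise_op_add.cc would instantiate kernels with
+//   REGISTER3(BinaryOp, CPU, "Add", functor::add, float, double, int32);
+// and a unary one such as cwise_op_sqrt.cc with
+//   REGISTER2(UnaryOp, CPU, "Sqrt", functor::sqrt, float, double);
+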
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
new file mode 100644
index 0000000000..b0dc027144
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -0,0 +1,135 @@
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+
+#define EIGEN_USE_GPU
+
+#include <complex>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+typedef std::complex<float> complex64;
+
+// Partial specialization of UnaryFunctor<Device=GPUDevice, Functor>.
+template <typename Functor>
+struct UnaryFunctor<GPUDevice, Functor> {
+ void operator()(const GPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in) {
+ out.device(d) = in.unaryExpr(typename Functor::func());
+ }
+};
+
+// Partial specialization of BinaryFunctor<Device=GPUDevice, Functor>.
+template <typename Functor, int NDIMS>
+struct BinaryFunctor<GPUDevice, Functor, NDIMS> {
+ void operator()(const GPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1) {
+ out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+ }
+
+ void Left(const GPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tscalar_type scalar,
+ typename Functor::tin_type in) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
+ out.device(d) = in.unaryExpr(Unary(scalar.data()));
+ }
+
+ void Right(const GPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in,
+ typename Functor::tscalar_type scalar) {
+ typedef typename Functor::out_type Tout;
+ typedef typename Functor::in_type Tin;
+ typedef typename Functor::func Binary;
+ typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
+ out.device(d) = in.unaryExpr(Unary(scalar.data()));
+ }
+
+ void BCast(const GPUDevice& d,
+ typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
+ typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
+ typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1) {
+ typedef typename Functor::in_type T;
+ typename Functor::func func;
+ if ((NDIMS == 2) && Functor::use_bcast_optimization &&
+ use_bcast_optimization<T>::value) {
+ const bool bcast0_all_one = AllOne<NDIMS>(bcast0);
+ const bool bcast1_all_one = AllOne<NDIMS>(bcast1);
+ if (bcast0_all_one && !bcast1_all_one) {
+ out.device(d) = in0.binaryExpr(in1.broadcast(bcast1), func);
+ return;
+ }
+ if (!bcast0_all_one && bcast1_all_one) {
+ out.device(d) = in0.broadcast(bcast0).binaryExpr(in1, func);
+ return;
+ }
+ }
+ out.device(d) =
+ in0.broadcast(bcast0).binaryExpr(in1.broadcast(bcast1), func);
+ }
+};
+
+template <typename T>
+struct SelectFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<bool>::ConstFlat cond_flat,
+ typename TTypes<T>::ConstFlat then_flat,
+ typename TTypes<T>::ConstFlat else_flat) {
+ out.device(d) = cond_flat.select(then_flat, else_flat);
+ }
+};
+
+// Macros to explicitly instantiate kernels on GPU for multiple types
+// (T0, T1, etc.) for UnaryFunctor (e.g., functor::sqrt).
+#define DEFINE_UNARY1(F, T) template struct UnaryFunctor<GPUDevice, F<T> >
+#define DEFINE_UNARY2(F, T0, T1) \
+ DEFINE_UNARY1(F, T0); \
+ DEFINE_UNARY1(F, T1)
+#define DEFINE_UNARY3(F, T0, T1, T2) \
+ DEFINE_UNARY2(F, T0, T1); \
+ DEFINE_UNARY1(F, T2)
+#define DEFINE_UNARY4(F, T0, T1, T2, T3) \
+ DEFINE_UNARY2(F, T0, T1); \
+ DEFINE_UNARY2(F, T2, T3)
+#define DEFINE_UNARY5(F, T0, T1, T2, T3, T4) \
+ DEFINE_UNARY2(F, T0, T1); \
+ DEFINE_UNARY3(F, T2, T3, T4)
+
+// Macros to explicitly instantiate kernels on GPU for multiple types
+// (T0, T1, etc.) for BinaryFunctor.
+#define DEFINE_BINARY1(F, T) \
+ template struct BinaryFunctor<GPUDevice, F<T>, 1>; \
+ template struct BinaryFunctor<GPUDevice, F<T>, 2>; \
+ template struct BinaryFunctor<GPUDevice, F<T>, 3>
+#define DEFINE_BINARY2(F, T0, T1) \
+ DEFINE_BINARY1(F, T0); \
+ DEFINE_BINARY1(F, T1)
+#define DEFINE_BINARY3(F, T0, T1, T2) \
+ DEFINE_BINARY2(F, T0, T1); \
+ DEFINE_BINARY1(F, T2)
+#define DEFINE_BINARY4(F, T0, T1, T2, T3) \
+ DEFINE_BINARY2(F, T0, T1); \
+ DEFINE_BINARY2(F, T2, T3)
+#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \
+ DEFINE_BINARY2(F, T0, T1); \
+ DEFINE_BINARY3(F, T2, T3, T4)
+
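+// Usage sketch (the op files named here are illustrative): a GPU kernel
+// definition file such as cwise_op_gpu_add.cu.cc would instantiate its
+// functors with
+//   DEFINE_BINARY3(add, float, double, int32);
+// and a unary one such as cwise_op_gpu_sqrt.cu.cc with
+//   DEFINE_UNARY2(sqrt, float, double);
+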
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
new file mode 100644
index 0000000000..56af248117
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -0,0 +1,167 @@
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// Creates a Graph which applies a unary "func" on a 3D float tensor
+// of "num" elements.
+static Graph* Unary(const string& func, int num) {
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+ CHECK_GT(data.NumElements(), 0);
+ data.flat<float>().setRandom();
+ test::graph::Unary(g, func, test::graph::Constant(g, data), 0);
+ return g;
+}
+
+static int kRows = 100000;
+
+static int RowsAndColsArg(int r, int c) { return r * kRows + c; }
+static int RowsFromArg(int arg) { return (arg / kRows); }
+static int ColsFromArg(int arg) { return (arg % kRows); }
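+// For example, RowsAndColsArg(512, 2048) == 512 * 100000 + 2048 == 51202048,
+// which RowsFromArg/ColsFromArg decode back to (512, 2048).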
+
+#define BM_UNARY(DEVICE, FUNC) \
+ static void BM_##DEVICE##_##FUNC(int iters, int num) { \
+ const int64 tot = static_cast<int64>(iters) * num; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, Unary(#FUNC, num)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##FUNC)->Range(4 << 10, 1 << 20);
+
+BM_UNARY(cpu, Floor);
+BM_UNARY(gpu, Floor);
+
+// Creates a Graph that computes "data func scalar", i.e., applies a binary
+// "func" to a 3D float tensor of "num" elements and a scalar.
+static Graph* BinaryScalar(int num, const string& func) {
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor lhs(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+ lhs.flat<float>().setRandom();
+ Tensor rhs(DT_FLOAT, TensorShape({}));
+ rhs.flat<float>().setRandom();
+ test::graph::Binary(g, func, test::graph::Constant(g, lhs),
+ test::graph::Constant(g, rhs));
+ return g;
+}
+
+#define BM_BINARY_SCALAR(DEVICE, FUNC) \
+ static void BM_##DEVICE##_##FUNC##_scalar(int iters, int num) { \
+ const int64 tot = static_cast<int64>(iters) * num; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, BinaryScalar(num, #FUNC)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##FUNC##_scalar) \
+ ->Arg(4096) /* must be >= 4096 */ \
+ ->Arg(32768) \
+ ->Arg(131072) \
+ ->Arg(1048576);
+
+BM_BINARY_SCALAR(cpu, Less);
+BM_BINARY_SCALAR(gpu, Less);
+BM_BINARY_SCALAR(cpu, Add);
+BM_BINARY_SCALAR(gpu, Add);
+#undef BM_BINARY_SCALAR
+
+static Graph* BiasAdd(int rows, int cols) {
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor lhs(DT_FLOAT, TensorShape({rows, cols}));
+ lhs.flat<float>().setRandom();
+ TensorShape rhs_shape;
+ rhs_shape = TensorShape({cols});
+ Tensor rhs(DT_FLOAT, rhs_shape);
+ rhs.flat<float>().setRandom();
+ test::graph::Binary(g, "BiasAdd", test::graph::Constant(g, lhs),
+ test::graph::Constant(g, rhs));
+ return g;
+}
+
+#define BM_BIAS_ADD(DEVICE, R, C) \
+ static void BM_##DEVICE##_BiasAdd_R##R##_C##C(int iters, int arg) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, BiasAdd(rows, cols)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_BiasAdd_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
+
+#define BM_BIAS_ADD_ALL(DEVICE) \
+ BM_BIAS_ADD(DEVICE, 512, 2048); \
+ BM_BIAS_ADD(DEVICE, 512, 4096); \
+ BM_BIAS_ADD(DEVICE, 2048, 512); \
+ BM_BIAS_ADD(DEVICE, 4096, 512);
+
+BM_BIAS_ADD_ALL(cpu);
+BM_BIAS_ADD_ALL(gpu);
+#undef BM_BIAS_ADD_ALL
+#undef BM_BIAS_ADD
+
+static Graph* BcastAdd(int rows, int cols, int dim) {
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor lhs(DT_FLOAT, TensorShape({rows, cols}));
+ lhs.flat<float>().setRandom();
+ TensorShape rhs_shape;
+ if (dim == 0) {
+ rhs_shape = TensorShape({rows, 1});
+ } else {
+ rhs_shape = TensorShape({cols});
+ }
+ Tensor rhs(DT_FLOAT, rhs_shape);
+ rhs.flat<float>().setRandom();
+ test::graph::Binary(g, "Add", test::graph::Constant(g, lhs),
+ test::graph::Constant(g, rhs));
+ return g;
+}
+
+#define BM_BCAST_ADD_ROW(DEVICE, R, C) \
+ static void BM_##DEVICE##_BcastAddRow_R##R##_C##C(int iters, int arg) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, BcastAdd(rows, cols, 0)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_BcastAddRow_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
+
+#define BM_BCAST_ADD_ROW_ALL(DEVICE) \
+ BM_BCAST_ADD_ROW(DEVICE, 512, 2048); \
+ BM_BCAST_ADD_ROW(DEVICE, 512, 4096); \
+ BM_BCAST_ADD_ROW(DEVICE, 2048, 512); \
+ BM_BCAST_ADD_ROW(DEVICE, 4096, 512);
+BM_BCAST_ADD_ROW_ALL(cpu);
+BM_BCAST_ADD_ROW_ALL(gpu);
+#undef BM_BCAST_ADD_ROW_ALL
+#undef BM_BCAST_ADD_ROW
+
+#define BM_BCAST_ADD_COL(DEVICE, R, C) \
+ static void BM_##DEVICE##_BcastAddCol_R##R##_C##C(int iters, int arg) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(float)); \
+ test::Benchmark(#DEVICE, BcastAdd(rows, cols, 1)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_BcastAddCol_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
+
+#define BM_BCAST_ADD_COL_ALL(DEVICE) \
+ BM_BCAST_ADD_COL(DEVICE, 512, 2048); \
+ BM_BCAST_ADD_COL(DEVICE, 512, 4096); \
+ BM_BCAST_ADD_COL(DEVICE, 2048, 512); \
+ BM_BCAST_ADD_COL(DEVICE, 4096, 512);
+BM_BCAST_ADD_COL_ALL(cpu);
+BM_BCAST_ADD_COL_ALL(gpu);
+#undef BM_BCAST_ADD_COL_ALL
+#undef BM_BCAST_ADD_COL
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
new file mode 100644
index 0000000000..0919bab96f
--- /dev/null
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -0,0 +1,222 @@
+// See docs in ../ops/parsing_ops.cc.
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class DecodeCSVOp : public OpKernel {
+ public:
+ explicit DecodeCSVOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string delim;
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("OUT_TYPE", &out_type_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("field_delim", &delim));
+
+ OP_REQUIRES(ctx, delim.size() == 1,
+ errors::InvalidArgument("field_delim should be only 1 char"));
+
+ delim_ = delim[0];
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* records;
+ OpInputList record_defaults;
+
+ OP_REQUIRES_OK(ctx, ctx->input("records", &records));
+ OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
+
+ for (int i = 0; i < record_defaults.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
+ errors::InvalidArgument(
+ "There should only be 1 default per field but field ", i,
+ " has ", record_defaults[i].NumElements()));
+ }
+
+ auto records_t = records->flat<string>();
+ int records_size = records_t.size();
+
+ OpOutputList output;
+ OP_REQUIRES_OK(ctx, ctx->output_list("output", &output));
+
+ for (size_t i = 0; i < out_type_.size(); ++i) {
+ Tensor* out = nullptr;
+ output.allocate(i, records->shape(), &out);
+ }
+
+ for (int i = 0; i < records_size; ++i) {
+ const StringPiece record(records_t(i));
+ std::vector<string> fields;
+ ExtractFields(ctx, record, &fields);
+ OP_REQUIRES(ctx, fields.size() == out_type_.size(),
+ errors::InvalidArgument("Expect ", out_type_.size(),
+ " fields but have ", fields.size(),
+ " in record ", i));
+
+ // Check each field in the record
+ for (size_t f = 0; f < out_type_.size(); ++f) {
+ const DataType& dtype = out_type_[f];
+ switch (dtype) {
+ case DT_INT32: {
+ // If this field is empty, check if default is given:
+ // If yes, use default value; Otherwise report error.
+ if (fields[f].empty()) {
+ OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
+ errors::InvalidArgument(
+ "Field ", f,
+ " is required but missing in record ", i, "!"));
+
+ output[f]->flat<int32>()(i) = record_defaults[f].flat<int32>()(0);
+ } else {
+ int32 value;
+ OP_REQUIRES(ctx, strings::safe_strto32(fields[f].c_str(), &value),
+ errors::InvalidArgument("Field ", f, " in record ", i,
+ " is not a valid int32: ",
+ fields[f]));
+ output[f]->flat<int32>()(i) = value;
+ }
+ break;
+ }
+ case DT_INT64: {
+ // If this field is empty, check if default is given:
+ // If yes, use default value; Otherwise report error.
+ if (fields[f].empty()) {
+ OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
+ errors::InvalidArgument(
+ "Field ", f,
+ " is required but missing in record ", i, "!"));
+
+ output[f]->flat<int64>()(i) = record_defaults[f].flat<int64>()(0);
+ } else {
+ int64 value;
+ OP_REQUIRES(ctx, strings::safe_strto64(fields[f].c_str(), &value),
+ errors::InvalidArgument("Field ", f, " in record ", i,
+ " is not a valid int64: ",
+ fields[f]));
+ output[f]->flat<int64>()(i) = value;
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ // If this field is empty, check if default is given:
+ // If yes, use default value; Otherwise report error.
+ if (fields[f].empty()) {
+ OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
+ errors::InvalidArgument(
+ "Field ", f,
+ " is required but missing in record ", i, "!"));
+ output[f]->flat<float>()(i) = record_defaults[f].flat<float>()(0);
+ } else {
+ float value;
+ OP_REQUIRES(ctx, strings::safe_strtof(fields[f].c_str(), &value),
+ errors::InvalidArgument("Field ", f, " in record ", i,
+ " is not a valid float: ",
+ fields[f]));
+ output[f]->flat<float>()(i) = value;
+ }
+ break;
+ }
+ case DT_STRING: {
+ // If this field is empty, check if default is given:
+ // If yes, use default value; Otherwise report error.
+ if (fields[f].empty()) {
+ OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
+ errors::InvalidArgument(
+ "Field ", f,
+ " is required but missing in record ", i, "!"));
+ output[f]->flat<string>()(i) =
+ record_defaults[f].flat<string>()(0);
+ } else {
+ output[f]->flat<string>()(i) = fields[f];
+ }
+ break;
+ }
+ default:
+ OP_REQUIRES(ctx, false,
+ errors::InvalidArgument("csv: data type ", dtype,
+ " not supported in field ", f));
+ }
+ }
+ }
+ }
+
+ private:
+ std::vector<DataType> out_type_;
+ char delim_;
+
+ void ExtractFields(OpKernelContext* ctx, StringPiece input,
+ std::vector<string>* result) {
+ int current_idx = 0;
+ if (!input.empty()) {
+ while (static_cast<size_t>(current_idx) < input.size()) {
+ if (input[current_idx] == '\n' || input[current_idx] == '\r') {
+ current_idx++;
+ continue;
+ }
+
+ bool quoted = false;
+ if (input[current_idx] == '"') {
+ quoted = true;
+ current_idx++;
+ }
+
+ // This is the body of the field.
+ string field;
+ if (!quoted) {
+ while (static_cast<size_t>(current_idx) < input.size() &&
+ input[current_idx] != delim_) {
+ OP_REQUIRES(ctx, input[current_idx] != '"' &&
+ input[current_idx] != '\n' &&
+ input[current_idx] != '\r',
+ errors::InvalidArgument(
+ "Unquoted fields cannot have quotes/CRLFs inside"));
+ field += input[current_idx];
+ current_idx++;
+ }
+
+ // Go to next field or the end
+ current_idx++;
+ } else {
+ // A quoted field must end with '"' followed by delim or end of line.
+ while (
+ (static_cast<size_t>(current_idx) < input.size() - 1) &&
+ (input[current_idx] != '"' || input[current_idx + 1] != delim_)) {
+ if (input[current_idx] != '"') {
+ field += input[current_idx];
+ current_idx++;
+ } else {
+ OP_REQUIRES(
+ ctx, input[current_idx + 1] == '"',
+ errors::InvalidArgument("Quote inside a string has to be "
+ "escaped by another quote"));
+ field += '"';
+ current_idx += 2;
+ }
+ }
+
+ OP_REQUIRES(
+ ctx,
+ input[current_idx] == '"' &&
+ (static_cast<size_t>(current_idx) == input.size() - 1 ||
+ input[current_idx + 1] == delim_),
+ errors::InvalidArgument("Quoted field has to end with quote "
+ "followed by delim or end"));
+
+ current_idx += 2;
+ }
+
+ result->push_back(field);
+ }
+
+ // Check if the last field is missing
+ if (input[input.size() - 1] == delim_) result->push_back(string());
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DecodeCSV").Device(DEVICE_CPU), DecodeCSVOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_jpeg_op.cc b/tensorflow/core/kernels/decode_jpeg_op.cc
new file mode 100644
index 0000000000..e41d3f3e11
--- /dev/null
+++ b/tensorflow/core/kernels/decode_jpeg_op.cc
@@ -0,0 +1,72 @@
+// See docs in ../ops/image_ops.cc
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+
+namespace tensorflow {
+
+// Decode the contents of a JPEG file
+class DecodeJpegOp : public OpKernel {
+ public:
+ explicit DecodeJpegOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("channels", &flags_.components));
+ OP_REQUIRES(context, flags_.components == 0 || flags_.components == 1 ||
+ flags_.components == 3,
+ errors::InvalidArgument("channels must be 0, 1, or 3, got ",
+ flags_.components));
+ OP_REQUIRES_OK(context, context->GetAttr("ratio", &flags_.ratio));
+ OP_REQUIRES(context, flags_.ratio == 1 || flags_.ratio == 2 ||
+ flags_.ratio == 4 || flags_.ratio == 8,
+ errors::InvalidArgument("ratio must be 1, 2, 4, or 8, got ",
+ flags_.ratio));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("fancy_upscaling", &flags_.fancy_upscaling));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("try_recover_truncated",
+ &flags_.try_recover_truncated_jpeg));
+ OP_REQUIRES_OK(context, context->GetAttr("acceptable_fraction",
+ &flags_.min_acceptable_fraction));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& contents = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
+ errors::InvalidArgument("contents must be scalar, got shape ",
+ contents.shape().ShortDebugString()));
+ const StringPiece input = contents.scalar<string>()();
+ OP_REQUIRES(context, input.size() <= std::numeric_limits<int>::max(),
+ errors::InvalidArgument("JPEG contents are too large for int: ",
+ input.size()));
+
+ // Decode image, allocating tensor once the image size is known
+ Tensor* output = NULL;
+ OP_REQUIRES(
+ context,
+ jpeg::Uncompress(
+ input.data(), input.size(), flags_, NULL,
+ [=, &output](int width, int height, int channels) -> uint8* {
+ Status status(context->allocate_output(
+ 0, TensorShape({height, width, channels}), &output));
+ if (!status.ok()) {
+ VLOG(1) << status;
+ context->SetStatus(status);
+ return nullptr;
+ }
+ return output->flat<uint8>().data();
+ }),
+ errors::InvalidArgument("Invalid JPEG data, size ", input.size()));
+ }
+
+ private:
+ jpeg::UncompressFlags flags_;
+};
+REGISTER_KERNEL_BUILDER(Name("DecodeJpeg").Device(DEVICE_CPU), DecodeJpegOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_png_op.cc b/tensorflow/core/kernels/decode_png_op.cc
new file mode 100644
index 0000000000..e8071526f9
--- /dev/null
+++ b/tensorflow/core/kernels/decode_png_op.cc
@@ -0,0 +1,69 @@
+// See docs in ../ops/image_ops.cc
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/png/png_io.h"
+
+namespace tensorflow {
+
+// Decode the contents of a PNG file
+class DecodePngOp : public OpKernel {
+ public:
+ explicit DecodePngOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
+ OP_REQUIRES(context, channels_ == 0 || channels_ == 1 || channels_ == 3 ||
+ channels_ == 4,
+ errors::InvalidArgument("channels must be 0, 1, 3, or 4, got ",
+ channels_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& contents = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()),
+ errors::InvalidArgument("contents must be scalar, got shape ",
+ contents.shape().ShortDebugString()));
+
+ // Start decoding image to get shape details
+ const StringPiece data = contents.scalar<string>()();
+ png::DecodeContext decode;
+ OP_REQUIRES(
+ context, png::CommonInitDecode(data, channels_, 8, &decode),
+ errors::InvalidArgument("Invalid PNG header, data size ", data.size()));
+
+ // Verify that width and height don't overflow int
+ const int width = decode.width;
+ const int height = decode.height;
+ if (width != static_cast<int64>(decode.width) ||
+ height != static_cast<int64>(decode.height)) {
+ png::CommonFreeDecode(&decode);
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument("PNG size too large for int: ",
+ decode.width, " by ", decode.height));
+ }
+
+ // Allocate tensor
+ Tensor* output = nullptr;
+ const auto status = context->allocate_output(
+ 0, TensorShape({height, width, decode.channels}), &output);
+ if (!status.ok()) png::CommonFreeDecode(&decode);
+ OP_REQUIRES_OK(context, status);
+
+ // Finish decoding image
+ OP_REQUIRES(
+ context, png::CommonFinishDecode(output->flat<uint8>().data(),
+ decode.channels * width, &decode),
+ errors::InvalidArgument("Invalid PNG data, size ", data.size()));
+ }
+
+ private:
+ int channels_;
+};
+REGISTER_KERNEL_BUILDER(Name("DecodePng").Device(DEVICE_CPU), DecodePngOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
new file mode 100644
index 0000000000..ef24c333a4
--- /dev/null
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -0,0 +1,90 @@
+// See docs in ../ops/parse_ops.cc.
+
+#include <algorithm>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
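+// DecodeRaw reinterprets the bytes of each input string as a vector of
+// "out_type" values, appending one dimension to the input shape. E.g. (a
+// sketch), with out_type=DT_FLOAT, a shape-[2] tensor of 8-byte strings
+// decodes into a [2, 2] float tensor.
+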
+template <typename T>
+class DecodeRawOp : public OpKernel {
+ public:
+ explicit DecodeRawOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("little_endian", &little_endian_));
+ OP_REQUIRES_OK(context, context->GetAttr("out_type", &out_type_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const auto& input = context->input(0);
+ int str_size = -1;
+ auto flat_in = input.flat<string>();
+ for (int i = 0; i < flat_in.size(); ++i) {
+ const string& in_str = flat_in(i);
+ if (str_size == -1) {
+ str_size = in_str.size();
+ } else {
+ OP_REQUIRES(context, str_size == in_str.size(),
+ errors::InvalidArgument(
+ "DecodeRaw requires input strings to all be the same "
+ "size, but element ",
+ i, " has size ", str_size, " != ", in_str.size()));
+ }
+ }
+ TensorShape out_shape = input.shape();
+ if (str_size == -1) { // Empty input
+ out_shape.AddDim(1);
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("output", out_shape,
+ &output_tensor));
+ return;
+ }
+ OP_REQUIRES(
+ context, str_size % sizeof(T) == 0,
+ errors::InvalidArgument("Input to DecodeRaw has length ", str_size,
+ " that is not a multiple of ", sizeof(T),
+ ", the size of ", DataTypeString(out_type_)));
+ const int added_dim = str_size / sizeof(T);
+ out_shape.AddDim(added_dim);
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output("output", out_shape, &output_tensor));
+ auto out = output_tensor->flat_inner_dims<T>();
+ DCHECK_EQ(flat_in.size(), out.dimensions()[0]);
+ OP_REQUIRES(
+ context,
+ little_endian_ == ::tensorflow::port::kLittleEndian || sizeof(T) == 1,
+ errors::Unimplemented("Unimplemented support for little_endian=",
+ little_endian_ ? "true" : "false"));
+ // Endianness matches, so just copy each string byte-for-byte.
+ T* out_data = out.data();
+ for (int i = 0; i < flat_in.size(); ++i) {
+ const T* in_data = reinterpret_cast<const T*>(flat_in(i).data());
+ memcpy(out_data, in_data, str_size);
+ out_data += added_dim;
+ }
+ }
+
+ private:
+ bool little_endian_;
+ DataType out_type_;
+};
+
+#define REGISTER(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DecodeRaw").Device(DEVICE_CPU).TypeConstraint<type>("out_type"), \
+ DecodeRawOp<type>)
+
+REGISTER(float);
+REGISTER(double);
+REGISTER(int32);
+REGISTER(uint8);
+REGISTER(int16);
+REGISTER(int8);
+REGISTER(int64);
+
+#undef REGISTER
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
new file mode 100644
index 0000000000..f56c37b4ef
--- /dev/null
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -0,0 +1,136 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/assign_op.h"
+#include "tensorflow/core/kernels/dense_update_ops.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+template <typename Device, typename T>
+class AssignOpT : public AssignOp {
+ public:
+ using AssignOp::AssignOp;
+
+ void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override {
+ functor::DenseUpdate<Device, T, ASSIGN> copy;
+ copy(context->eigen_device<Device>(), lhs->flat<T>(), rhs.flat<T>());
+ }
+};
+
+// TODO(jeff): Get rid of use_exclusive_lock_ option
+template <typename Device, typename T, DenseUpdateType OP>
+class DenseUpdateOp : public OpKernel {
+ public:
+ explicit DenseUpdateOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("use_locking", &use_exclusive_lock_));
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({MakeRefType(dt), dt},
+ {MakeRefType(dt)}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // We always return the input ref.
+ context->forward_ref_input_to_ref_output(0, 0);
+
+ if (use_exclusive_lock_) {
+ mutex_lock l(*context->input_ref_mutex(0));
+ DoUpdate(context);
+ } else {
+ DoUpdate(context);
+ }
+ }
+
+ private:
+ void DoUpdate(OpKernelContext* context) {
+ Tensor Tparams = context->mutable_input(0, use_exclusive_lock_);
+ const Tensor& Tupdate = context->input(1);
+ OP_REQUIRES(context, Tparams.IsInitialized(),
+ errors::FailedPrecondition("Attempting to use uninitialized "
+ "parameters: ",
+ def().input(0)));
+ OP_REQUIRES(
+ context, Tparams.IsSameSize(Tupdate),
+ errors::InvalidArgument("Parameters and update must be the same size"));
+
+ functor::DenseUpdate<Device, T, OP> update_functor;
+ update_functor(context->eigen_device<Device>(), Tparams.flat<T>(),
+ Tupdate.flat<T>());
+ }
+
+ bool use_exclusive_lock_;
+};
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Assign").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ AssignOpT<CPUDevice, type>);
+
+TF_CALL_ALL_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+// Only register 'Assign' on GPU for the subset of types also supported by
+// 'Variable' (see variable_ops.cc).
+#define REGISTER_GPU_KERNELS(type) \
+ namespace functor { \
+ template <> \
+ void DenseUpdate<GPUDevice, type, ASSIGN>::operator()( \
+ const GPUDevice& d, typename TTypes<type>::Flat lhs, \
+ typename TTypes<type>::ConstFlat rhs); \
+ extern template struct DenseUpdate<GPUDevice, type, ASSIGN>; \
+ } \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Assign").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ AssignOpT<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ DenseUpdateOp<CPUDevice, type, DenseUpdateType::ADD>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignSub").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ DenseUpdateOp<CPUDevice, type, DenseUpdateType::SUB>);
+
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
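+// The corresponding definitions are explicitly instantiated for each GPU type
+// in dense_update_ops_gpu.cu.cc, which is compiled with the CUDA toolchain.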
+namespace functor {
+#define DECLARE_GPU_SPEC_FOR_OP(T, OP) \
+ template <> \
+ void DenseUpdate<GPUDevice, T, OP>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat params, \
+ typename TTypes<T>::ConstFlat update); \
+ extern template struct DenseUpdate<GPUDevice, T, OP>
+#define DECLARE_GPU_SPEC(T) \
+ DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::ADD); \
+ DECLARE_GPU_SPEC_FOR_OP(T, DenseUpdateType::SUB)
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+#undef DECLARE_GPU_SPEC
+#undef DECLARE_GPU_SPEC_FOR_OP
+} // namespace functor
+
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ DenseUpdateOp<GPUDevice, type, DenseUpdateType::ADD>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("AssignSub").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ DenseUpdateOp<GPUDevice, type, DenseUpdateType::SUB>);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+#endif  // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_ops.h b/tensorflow/core/kernels/dense_update_ops.h
new file mode 100644
index 0000000000..d32c9a4af2
--- /dev/null
+++ b/tensorflow/core/kernels/dense_update_ops.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
+#define TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+enum DenseUpdateType { ADD, SUB, ASSIGN };
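+// ASSIGN overwrites params with update (the Assign kernel), while ADD and SUB
+// accumulate into params in place (the AssignAdd and AssignSub kernels); see
+// the functor specializations below and the registrations in
+// dense_update_ops.cc.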
+
+namespace functor {
+
+template <typename Device, typename T, DenseUpdateType OP>
+struct DenseUpdate;
+
+template <typename Device, typename T>
+struct DenseUpdate<Device, T, ADD> {
+ void operator()(const Device& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) += update;
+ }
+};
+
+template <typename Device, typename T>
+struct DenseUpdate<Device, T, SUB> {
+ void operator()(const Device& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) -= update;
+ }
+};
+
+template <typename Device, typename T>
+struct DenseUpdate<Device, T, ASSIGN> {
+ void operator()(const Device& d, typename TTypes<T>::Flat params,
+ typename TTypes<T>::ConstFlat update) {
+ params.device(d) = update;
+ }
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_OPS_H_
diff --git a/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc
new file mode 100644
index 0000000000..8e80901c71
--- /dev/null
+++ b/tensorflow/core/kernels/dense_update_ops_gpu.cu.cc
@@ -0,0 +1,22 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/dense_update_ops.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::DenseUpdate<GPUDevice, T, ADD>; \
+ template struct functor::DenseUpdate<GPUDevice, T, SUB>; \
+ template struct functor::DenseUpdate<GPUDevice, T, ASSIGN>;
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+#undef DEFINE_GPU_KERNELS
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/determinant_op.cc b/tensorflow/core/kernels/determinant_op.cc
new file mode 100644
index 0000000000..d34aab7a44
--- /dev/null
+++ b/tensorflow/core/kernels/determinant_op.cc
@@ -0,0 +1,66 @@
+// See docs in ../ops/linalg_ops.cc.
+#include <cmath>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/Eigen/LU"
+
+namespace tensorflow {
+
+template <class Scalar, bool SupportsBatchOperationT>
+class DeterminantOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
+ public:
+ explicit DeterminantOp(OpKernelConstruction* context)
+ : LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
+ ~DeterminantOp() override {}
+
+ TensorShape GetOutputMatrixShape(
+ const TensorShape& input_matrix_shape) override {
+ return TensorShape({});
+ }
+
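+  // Note: the determinant below is computed with an LU-style decomposition in
+  // Eigen, so for large matrices the cost is roughly cubic in the number of
+  // rows, hence rows * rows * rows below.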
+ int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override {
+ const int64 rows = input_matrix_shape.dim_size(0);
+ if (rows > (1LL << 20)) {
+      // Return a large constant to cap the cost in case of overflow.
+ return kint32max;
+ } else {
+ return rows * rows * rows;
+ }
+ }
+
+ using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
+ using
+ typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
+ MatrixMap* output) override {
+ OP_REQUIRES(context, input.rows() == input.cols(),
+ errors::InvalidArgument("Input matrix must be square."));
+ Scalar determinant;
+ if (input.rows() == 0) {
+      // The determinant of an empty (0 x 0) matrix is defined to be 1
+      // (see, e.g., Wikipedia).
+ determinant = 1;
+ } else {
+ determinant = input.determinant();
+ }
+ OP_REQUIRES(context, std::isfinite(determinant),
+ errors::Internal("The determinant is not finite."));
+ (*output)(0, 0) = determinant;
+ }
+};
+
+REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float, false>), float);
+REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double, false>), double);
+REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float, true>),
+ float);
+REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double, true>),
+ double);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc
new file mode 100644
index 0000000000..83e39d33a9
--- /dev/null
+++ b/tensorflow/core/kernels/diag_op.cc
@@ -0,0 +1,93 @@
+// See docs in ../ops/array_ops.cc
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace {
+template <typename T, size_t NumDims, size_t DoubleNumDims>
+class DiagonalGenerator {
+ public:
+ explicit DiagonalGenerator(const Tensor& diagonal) : diagonal_(diagonal) {
+ static_assert(DoubleNumDims == 2 * NumDims,
+ "The second size must be the double of the first size.");
+ CHECK_EQ(diagonal.dims(), NumDims);
+ }
+ T operator()(
+ const Eigen::array<Eigen::DenseIndex, DoubleNumDims>& coordinates) const {
+ Eigen::array<Eigen::DenseIndex, NumDims> index;
+ for (int i = 0; i < NumDims; ++i) {
+ if (coordinates[i] != coordinates[NumDims + i]) {
+ return T(0);
+ }
+ index[i] = coordinates[i];
+ }
+ return diagonal_.tensor<T, NumDims>()(index);
+ }
+
+ private:
+ Tensor diagonal_;
+};
+} // namespace
+
+// Generate the diagonal tensor with the diagonal set to the input tensor.
+// It only allows input tensors of rank 1 to 3, so the output tensor has rank
+// up to 6.
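+//
+// For illustration, a rank-1 diagonal [1, 2, 3] produces the rank-2 output
+//   [[1, 0, 0],
+//    [0, 2, 0],
+//    [0, 0, 3]].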
+template <typename T>
+class DiagOp : public OpKernel {
+ public:
+ explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& diagonal = context->input(0);
+ const int num_dims = diagonal.dims();
+ OP_REQUIRES(context, 1 <= num_dims,
+ errors::InvalidArgument(
+ "The rank of the diagonal should be between 1 and 3."));
+ OP_REQUIRES(context, 3 >= num_dims,
+ errors::InvalidArgument(
+ "The rank of the diagonal should be between 1 and 3."));
+ TensorShape out_shape;
+ for (int i = 0; i < num_dims; ++i) {
+ out_shape.AddDim(diagonal.dim_size(i));
+ }
+ for (int i = 0; i < num_dims; ++i) {
+ out_shape.AddDim(diagonal.dim_size(i));
+ }
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, out_shape, &output_tensor));
+ switch (num_dims) {
+ case 1:
+ output_tensor->tensor<T, 2>() = output_tensor->tensor<T, 2>().generate(
+ DiagonalGenerator<T, 1, 2>(diagonal));
+ break;
+ case 2:
+ output_tensor->tensor<T, 4>() = output_tensor->tensor<T, 4>().generate(
+ DiagonalGenerator<T, 2, 4>(diagonal));
+ break;
+ case 3:
+ output_tensor->tensor<T, 6>() = output_tensor->tensor<T, 6>().generate(
+ DiagonalGenerator<T, 3, 6>(diagonal));
+ break;
+ default:
+ context->SetStatus(errors::Unimplemented(
+ "Diagonal of rank ", num_dims, " tensor is not supported yet."));
+ return;
+ }
+ }
+};
+
+#define REGISTER_DIAGOP(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Diag").Device(DEVICE_CPU).TypeConstraint<T>("T"), DiagOp<T>)
+
+REGISTER_DIAGOP(double);
+REGISTER_DIAGOP(float);
+REGISTER_DIAGOP(int32);
+REGISTER_DIAGOP(int64);
+
+#undef REGISTER_DIAGOP
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc
new file mode 100644
index 0000000000..f1b44861b5
--- /dev/null
+++ b/tensorflow/core/kernels/dynamic_partition_op.cc
@@ -0,0 +1,154 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// Shared code that is not dependent on the type of T. We do this to reduce
+// code size by not duplicating all this for all T (float, double, int32, etc.).
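+//
+// For illustration of the op itself: with num_partitions = 2,
+// data = [10, 20, 30, 40, 50] and partitions = [0, 0, 1, 1, 0], the outputs
+// are outputs[0] = [10, 20, 50] and outputs[1] = [30, 40] (see
+// dynamic_partition_op_test.cc for more examples).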
+class DynamicPartitionOp_Shared : public OpKernel {
+ public:
+ explicit DynamicPartitionOp_Shared(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("num_partitions", &num_partitions_));
+ // QUESTION: It'd be nice to support DT_INT16, DT_UINT8, etc.
+ // to input[1]. Should we have the framework do some sort of
+ // integer promotion automatically, or should that be something
+ // that users have to do explicitly with a conversion operator
+ // in the graph?
+ }
+
+ void ValidateAndAllocateOutputs(OpKernelContext* c, const Tensor** data,
+ const Tensor** partitions,
+ OpOutputList* Tout) {
+ OP_REQUIRES_OK(c, c->input("data", data));
+ OP_REQUIRES_OK(c, c->input("partitions", partitions));
+ OP_REQUIRES(c, TensorShapeUtils::StartsWith((*data)->shape(),
+ (*partitions)->shape()),
+ errors::InvalidArgument(
+ "data.shape must start with partitions.shape, ",
+ "got data.shape = ", (*data)->shape().ShortDebugString(),
+ ", partitions.shape = ",
+ (*partitions)->shape().ShortDebugString()));
+
+ // Count how many occurrences of each partition id we have in partitions
+ gtl::InlinedVector<int, 32> partition_count(num_partitions_);
+ auto e_partitions = (*partitions)->flat<int32>();
+ const int64 N = e_partitions.dimension(0);
+ for (int64 i = 0; i < N; i++) {
+ const int32 p = e_partitions(i);
+ OP_REQUIRES(c, p >= 0 && p < num_partitions_,
+ errors::InvalidArgument(
+ "partitions", SliceString((*partitions)->shape(), i),
+ " = ", p, " is not in [0, ", num_partitions_, ")"));
+ partition_count[p]++;
+ }
+
+ // Allocate output tensors of the right size
+ OP_REQUIRES_OK(c, c->output_list("outputs", Tout));
+ for (int p = 0; p < num_partitions_; p++) {
+ TensorShape shape;
+ shape.AddDim(partition_count[p]);
+ for (int i = (*partitions)->dims(); i < (*data)->dims(); i++) {
+ shape.AddDim((*data)->dim_size(i));
+ }
+ Tensor* out;
+ OP_REQUIRES_OK(c, Tout->allocate(p, shape, &out));
+ }
+ }
+
+ protected:
+ int num_partitions_;
+
+ static string SliceString(const TensorShape& shape, const int64 flat) {
+ // Special case rank 0 and 1
+ const int dims = shape.dims();
+ if (dims == 0) return "";
+ if (dims == 1) return strings::StrCat("[", flat, "]");
+
+ // Compute strides
+ gtl::InlinedVector<int64, 32> strides(dims);
+ strides.back() = 1;
+ for (int i = dims - 2; i >= 0; i--) {
+ strides[i] = strides[i + 1] * shape.dim_size(i + 1);
+ }
+
+ // Unflatten index
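+    // For example, with shape [2, 3] and flat = 4 the strides are {3, 1},
+    // so the result is "[1,1]".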
+ int64 left = flat;
+ string result;
+ for (int i = 0; i < dims; i++) {
+ strings::StrAppend(&result, i ? "," : "[", left / strides[i]);
+ left %= strides[i];
+ }
+ strings::StrAppend(&result, "]");
+ return result;
+ }
+};
+
+template <class T>
+class DynamicPartitionOp : public DynamicPartitionOp_Shared {
+ public:
+ explicit DynamicPartitionOp(OpKernelConstruction* c)
+ : DynamicPartitionOp_Shared(c) {}
+ void Compute(OpKernelContext* c) override {
+ const Tensor* data;
+ const Tensor* partitions;
+ OpOutputList outputs;
+ ValidateAndAllocateOutputs(c, &data, &partitions, &outputs);
+ if (!c->status().ok()) return;
+ if (num_partitions_ == 0 || data->NumElements() == 0) return;
+
+ auto e_partitions = partitions->flat<int32>();
+ const int64 N = e_partitions.dimension(0);
+ gtl::InlinedVector<int, 32> output_index(num_partitions_);
+
+ if (partitions->dims() == data->dims()) {
+ // Walk through data and copy the data to the appropriate output tensor
+ const auto data_flat = data->flat<T>();
+ std::vector<Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
+ Eigen::Aligned> > out_vec;
+ for (int p = 0; p < num_partitions_; p++) {
+ out_vec.push_back(outputs[p]->vec<T>());
+ }
+ for (int64 i = 0; i < N; i++) {
+ const int32 p = e_partitions(i);
+ out_vec[p](output_index[p]) = data_flat(i);
+ output_index[p]++;
+ }
+ } else {
+ // If data has extra dimensions, use Eigen slices
+ std::vector<Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Aligned> > out_flat;
+ for (int p = 0; p < num_partitions_; p++) {
+ out_flat.push_back(outputs[p]->flat_outer_dims<T>());
+ }
+
+ // Walk through data and copy the data to the appropriate output tensor
+ const int64 slice_size = data->NumElements() / N;
+ const auto data_flat = data->shaped<T, 2>({N, slice_size});
+ Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
+ for (int64 i = 0; i < N; i++) {
+ const int32 p = e_partitions(i);
+ // outputs[p][output_index[p]++] = data[i]
+ Eigen::DSizes<Eigen::DenseIndex, 2> out_indices(output_index[p], 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
+ out_flat[p].slice(out_indices, sizes) =
+ data_flat.slice(data_indices, sizes);
+ output_index[p]++;
+ }
+ }
+ }
+};
+
+#define REGISTER_DYNAMIC_PARTITION(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DynamicPartition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ DynamicPartitionOp<T>)
+
+TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
+#undef REGISTER_DYNAMIC_PARTITION
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dynamic_partition_op_test.cc b/tensorflow/core/kernels/dynamic_partition_op_test.cc
new file mode 100644
index 0000000000..b0e5e7deb0
--- /dev/null
+++ b/tensorflow/core/kernels/dynamic_partition_op_test.cc
@@ -0,0 +1,145 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+class DynamicPartitionOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "DynamicPartition")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Attr("num_partitions", 4)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(DynamicPartitionOpTest, Simple_OneD) {
+ MakeOp();
+
+ // Similar to how we would use this to split embedding ids to be looked up
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({6}), {0, 13, 2, 39, 4, 17});
+ AddInputFromArray<int32>(TensorShape({6}), {0, 0, 2, 3, 2, 1});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output sizes
+ { // Output 0
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&expected, {0, 13});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+ }
+ { // Output 1
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1}));
+ test::FillValues<float>(&expected, {17});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(1));
+ }
+ { // Output 2
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&expected, {2, 4});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(2));
+ }
+ { // Output 3
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1}));
+ test::FillValues<float>(&expected, {39});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(3));
+ }
+}
+
+TEST_F(DynamicPartitionOpTest, Simple_TwoD) {
+ MakeOp();
+
+ // Feed and run
+ AddInputFromArray<float>(
+ TensorShape({6, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
+ AddInputFromArray<int32>(TensorShape({6}), {0, 0, 2, 3, 2, 1});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output sizes
+ { // Output 0
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
+ test::FillValues<float>(&expected, {0, 1, 2, 3, 4, 5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+ }
+ { // Output 1
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3}));
+ test::FillValues<float>(&expected, {15, 16, 17});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(1));
+ }
+ { // Output 2
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
+ test::FillValues<float>(&expected, {6, 7, 8, 12, 13, 14});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(2));
+ }
+ { // Output 3
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3}));
+ test::FillValues<float>(&expected, {9, 10, 11});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(3));
+ }
+}
+
+TEST_F(DynamicPartitionOpTest, SomeOutputsEmpty) {
+ MakeOp();
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({6}), {0, 13, 2, 39, 4, 17});
+ AddInputFromArray<int32>(TensorShape({6}), {0, 0, 2, 2, 0, 2});
+ ASSERT_OK(RunOpKernel());
+
+ TensorShape empty_one_dim;
+ empty_one_dim.AddDim(0);
+ Tensor expected_empty(allocator(), DT_FLOAT, empty_one_dim);
+
+ // Check the output sizes
+ { // Output 0
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
+ test::FillValues<float>(&expected, {0, 13, 4});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+ }
+ { // Output 1
+ test::ExpectTensorEqual<float>(expected_empty, *GetOutput(1));
+ }
+ { // Output 2
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({3}));
+ test::FillValues<float>(&expected, {2, 39, 17});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(2));
+ }
+ { // Output 3
+ test::ExpectTensorEqual<float>(expected_empty, *GetOutput(3));
+ }
+}
+
+TEST_F(DynamicPartitionOpTest, Error_IndexOutOfRange) {
+ MakeOp();
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({5}), {0, 2, 99, 2, 2});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(
+ StringPiece(s.ToString()).contains("partitions[2] = 99 is not in [0, 4)"))
+ << s;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
new file mode 100644
index 0000000000..a5623685fb
--- /dev/null
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -0,0 +1,158 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+template <class T>
+class DynamicStitchOp : public OpKernel {
+ public:
+ explicit DynamicStitchOp(OpKernelConstruction* c) : OpKernel(c) {
+ // Compute expected input signature
+ const DataType dt = DataTypeToEnum<T>::v();
+ const int n = c->num_inputs() / 2;
+ DataTypeVector expected;
+ for (int i = 0; i < n; i++) {
+ expected.push_back(DT_INT32);
+ }
+ for (int i = 0; i < n; i++) {
+ expected.push_back(dt);
+ }
+ OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt}));
+ OP_REQUIRES(
+ c, c->num_inputs() > 0,
+ errors::InvalidArgument("DynamicStitchOp: Must have some inputs"));
+ OP_REQUIRES(c, c->num_inputs() % 2 == 0,
+ errors::InvalidArgument(
+ "DynamicStitchOp: Must have even number of arguments"));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ // Find maximum index in the indices vectors
+ OpInputList indices_inputs;
+ OP_REQUIRES_OK(c, c->input_list("indices", &indices_inputs));
+
+ int32 max_index = -1;
+ for (const Tensor& indices : indices_inputs) {
+ Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
+ indices.flat<int32>().maximum();
+ max_index = std::max(m(), max_index);
+ }
+ const int first_dim_size = max_index + 1;
+
+ // Validate that data[i].shape = indices[i].shape + constant
+ OpInputList data_inputs;
+ OP_REQUIRES_OK(c, c->input_list("data", &data_inputs));
+ const Tensor& data0 = data_inputs[0];
+ const Tensor& indices0 = indices_inputs[0];
+ for (int input_num = 0; input_num < indices_inputs.size(); input_num++) {
+ const Tensor& indices = indices_inputs[input_num];
+ const Tensor& data = data_inputs[input_num];
+ OP_REQUIRES(
+ c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
+ errors::InvalidArgument(
+ "data[", input_num, "].shape = ", data.shape().ShortDebugString(),
+ " does not start with indices[", input_num, "].shape = ",
+ indices.shape().ShortDebugString()));
+ OP_REQUIRES(
+ c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
+ errors::InvalidArgument(
+ "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
+ "].shape[", indices.dims(), ":], got data[0].shape = ",
+ data0.shape().ShortDebugString(), ", data[", input_num,
+ "].shape = ", data.shape().ShortDebugString(),
+ ", indices[0].shape = ", indices0.shape().ShortDebugString(),
+ ", indices[", input_num, "].shape = ",
+ indices.shape().ShortDebugString()));
+ }
+
+ // Allocate result tensor of shape
+ // [first_dim_size] + data.shape[indices.dims:]
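+    // For example, index vectors [0, 4, 7] and [1, 6, 2, 3, 5] with scalar
+    // data entries give first_dim_size = 8 and a merged shape of [8]; see
+    // Simple_OneD in dynamic_stitch_op_test.cc.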
+ TensorShape result_shape;
+ result_shape.AddDim(first_dim_size);
+ for (int d = indices0.dims(); d < data0.dims(); d++) {
+ result_shape.AddDim(data0.dim_size(d));
+ }
+ Tensor* merged = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &merged));
+
+ // TODO(jeff): Currently we leave uninitialized any portions of
+ // merged that aren't covered by an index in indices. What should we do?
+ if (first_dim_size > 0) {
+ auto merged_flat = merged->flat_outer_dims<T>();
+ const int slice_size = merged_flat.dimension(1);
+ for (int input_num = 0; input_num < indices_inputs.size(); input_num++) {
+ const Tensor& indices = indices_inputs[input_num];
+ auto indices_vec = indices.flat<int32>();
+ const Tensor& data = data_inputs[input_num];
+ auto data_flat =
+ data.shaped<T, 2>({indices_vec.dimension(0), slice_size});
+
+ if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ T* merged_base = &merged_flat(0, 0);
+ const T* data_base = &data_flat(0, 0);
+ const size_t slice_bytes = slice_size * sizeof(T);
+ for (int i = 0; i < indices_vec.size(); i++) {
+ memcpy(merged_base + indices_vec(i) * slice_size,
+ data_base + i * slice_size, slice_bytes);
+ }
+ } else {
+ Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
+ for (int i = 0; i < indices_vec.size(); i++) {
+ // Copy slice data[i] to merged[indices[i]]
+ Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(indices_vec(i),
+ 0);
+ merged_flat.slice(merged_indices, sizes) =
+ data_flat.slice(data_indices, sizes);
+ }
+ }
+ }
+ }
+ }
+
+ private:
+ // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():]
+ static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
+ const Tensor& data1, const Tensor& indices1) {
+ const int extra0 = data0.dims() - indices0.dims();
+ const int extra1 = data1.dims() - indices1.dims();
+ if (extra0 != extra1) return false;
+ for (int i = 0; i < extra0; i++) {
+ if (data0.dim_size(indices0.dims() + i) !=
+ data1.dim_size(indices1.dims() + i)) {
+ return false;
+ }
+ }
+ return true;
+ }
+};
+
+#define REGISTER_DYNAMIC_STITCH(type) \
+ REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("indices"), \
+ DynamicStitchOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH);
+#undef REGISTER_DYNAMIC_STITCH
+
+#if GOOGLE_CUDA
+#define REGISTER_DYNAMIC_STITCH_GPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("indices") \
+ .HostMemory("data") \
+ .HostMemory("merged"), \
+ DynamicStitchOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
+#undef REGISTER_DYNAMIC_STITCH_GPU
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dynamic_stitch_op_test.cc b/tensorflow/core/kernels/dynamic_stitch_op_test.cc
new file mode 100644
index 0000000000..8c71f0fd0f
--- /dev/null
+++ b/tensorflow/core/kernels/dynamic_stitch_op_test.cc
@@ -0,0 +1,133 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+class DynamicStitchOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(int n, DataType dt) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "DynamicStitch")
+ .Input(FakeInput(n, DT_INT32))
+ .Input(FakeInput(n, dt))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(DynamicStitchOpTest, Simple_OneD) {
+ MakeOp(2, DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 7});
+ AddInputFromArray<int32>(TensorShape({5}), {1, 6, 2, 3, 5});
+ AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
+ AddInputFromArray<float>(TensorShape({5}), {10, 60, 20, 30, 50});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
+ test::FillValues<float>(&expected, {0, 10, 20, 30, 40, 50, 60, 70});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(DynamicStitchOpTest, Simple_TwoD) {
+ MakeOp(3, DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 7});
+ AddInputFromArray<int32>(TensorShape({2}), {1, 6});
+ AddInputFromArray<int32>(TensorShape({3}), {2, 3, 5});
+ AddInputFromArray<float>(TensorShape({3, 2}), {0, 1, 40, 41, 70, 71});
+ AddInputFromArray<float>(TensorShape({2, 2}), {10, 11, 60, 61});
+ AddInputFromArray<float>(TensorShape({3, 2}), {20, 21, 30, 31, 50, 51});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({8, 2}));
+ test::FillValues<float>(&expected, {0, 1, 10, 11, 20, 21, 30, 31, 40, 41, 50,
+ 51, 60, 61, 70, 71});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(DynamicStitchOpTest, Error_IndicesMultiDimensional) {
+ MakeOp(2, DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 7});
+ AddInputFromArray<int32>(TensorShape({1, 5}), {1, 6, 2, 3, 5});
+ AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
+ AddInputFromArray<float>(TensorShape({5}), {10, 60, 20, 30, 50});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("data[1].shape = [5] does not start with "
+ "indices[1].shape = [1,5]"))
+ << s;
+}
+
+TEST_F(DynamicStitchOpTest, Error_DataNumDimsMismatch) {
+ MakeOp(2, DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 7});
+ AddInputFromArray<int32>(TensorShape({5}), {1, 6, 2, 3, 5});
+ AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
+ AddInputFromArray<float>(TensorShape({1, 5}), {10, 60, 20, 30, 50});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("data[1].shape = [1,5] does not start with "
+ "indices[1].shape = [5]"))
+ << s;
+}
+
+TEST_F(DynamicStitchOpTest, Error_DataDimSizeMismatch) {
+ MakeOp(2, DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 5});
+ AddInputFromArray<int32>(TensorShape({4}), {1, 6, 2, 3});
+ AddInputFromArray<float>(TensorShape({3, 1}), {0, 40, 70});
+ AddInputFromArray<float>(TensorShape({4, 2}),
+ {10, 11, 60, 61, 20, 21, 30, 31});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Need data[0].shape[1:] = data[1].shape[1:], "
+ "got data[0].shape = [3,1], data[1].shape = [4,2]"))
+ << s;
+}
+
+TEST_F(DynamicStitchOpTest, Error_DataAndIndicesSizeMismatch) {
+ MakeOp(2, DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 7});
+ AddInputFromArray<int32>(TensorShape({5}), {1, 6, 2, 3, 5});
+ AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
+ AddInputFromArray<float>(TensorShape({4}), {10, 60, 20, 30});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(
+ StringPiece(s.ToString())
+ .contains(
+ "data[1].shape = [4] does not start with indices[1].shape = [5]"))
+ << s;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc
new file mode 100644
index 0000000000..938d7f056b
--- /dev/null
+++ b/tensorflow/core/kernels/edit_distance_op.cc
@@ -0,0 +1,217 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include <limits>
+#include <numeric>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/edit_distance.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices,
+ const Tensor& hypothesis_values,
+ const Tensor& hypothesis_shape,
+ const Tensor& truth_indices, const Tensor& truth_values,
+ const Tensor& truth_shape) {
+ if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape()))
+ return errors::InvalidArgument(
+ "hypothesis_indices should be a matrix, but got shape: ",
+ hypothesis_indices.shape().DebugString());
+ if (!TensorShapeUtils::IsMatrix(truth_indices.shape()))
+ return errors::InvalidArgument(
+ "truth_indices should be a matrix, but got shape: ",
+ truth_indices.shape().DebugString());
+ if (!TensorShapeUtils::IsVector(hypothesis_values.shape()))
+ return errors::InvalidArgument(
+ "hypothesis_values should be a vector, but got shape: ",
+ hypothesis_values.shape().DebugString());
+ if (!TensorShapeUtils::IsVector(truth_values.shape()))
+ return errors::InvalidArgument(
+ "truth_values should be a vector, but got shape: ",
+ truth_values.shape().DebugString());
+ if (!TensorShapeUtils::IsVector(hypothesis_shape.shape()))
+ return errors::InvalidArgument(
+ "hypothesis_shape should be a vector, but got shape: ",
+ hypothesis_shape.shape().DebugString());
+ if (!TensorShapeUtils::IsVector(truth_shape.shape()))
+ return errors::InvalidArgument(
+ "truth_shape should be a vector, but got shape: ",
+ truth_shape.shape().DebugString());
+ if (hypothesis_shape.NumElements() != hypothesis_indices.dim_size(1))
+ return errors::InvalidArgument(
+ "Expected hypothesis_shape.NumElements == "
+ "#cols(hypothesis_indices), their shapes are: ",
+ hypothesis_shape.shape().DebugString(), " and ",
+ hypothesis_indices.shape().DebugString());
+ if (truth_shape.NumElements() < 2)
+ return errors::InvalidArgument(
+ "Input SparseTensors must have rank at least 2, but truth_shape "
+ "rank is: ",
+ truth_shape.NumElements());
+ if (truth_shape.NumElements() != truth_indices.dim_size(1))
+ return errors::InvalidArgument(
+ "Expected truth_shape.NumElements == "
+ "#cols(truth_indices), their shapes are: ",
+ truth_shape.shape().DebugString(), " and ",
+ truth_indices.shape().DebugString());
+ if (truth_shape.NumElements() != hypothesis_shape.NumElements())
+ return errors::InvalidArgument(
+ "Expected truth and hypothesis to have matching ranks, but "
+ "their shapes are: ",
+ truth_shape.shape().DebugString(), " and ",
+ hypothesis_shape.shape().DebugString());
+
+ return Status::OK();
+}
+
+} // namespace
+
+template <typename T>
+class EditDistanceOp : public OpKernel {
+ public:
+ explicit EditDistanceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("normalize", &normalize_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* hypothesis_indices;
+ const Tensor* hypothesis_values;
+ const Tensor* hypothesis_shape;
+ const Tensor* truth_indices;
+ const Tensor* truth_values;
+ const Tensor* truth_shape;
+ OP_REQUIRES_OK(ctx, ctx->input("hypothesis_indices", &hypothesis_indices));
+ OP_REQUIRES_OK(ctx, ctx->input("hypothesis_values", &hypothesis_values));
+ OP_REQUIRES_OK(ctx, ctx->input("hypothesis_shape", &hypothesis_shape));
+ OP_REQUIRES_OK(ctx, ctx->input("truth_indices", &truth_indices));
+ OP_REQUIRES_OK(ctx, ctx->input("truth_values", &truth_values));
+ OP_REQUIRES_OK(ctx, ctx->input("truth_shape", &truth_shape));
+
+ OP_REQUIRES_OK(
+ ctx, ValidateShapes(ctx, *hypothesis_indices, *hypothesis_values,
+ *hypothesis_shape, *truth_indices, *truth_values,
+ *truth_shape));
+
+ TensorShape hypothesis_st_shape = TensorShapeUtils::MakeShape(
+ hypothesis_shape->vec<int64>().data(), hypothesis_shape->NumElements());
+ TensorShape truth_st_shape = TensorShapeUtils::MakeShape(
+ truth_shape->vec<int64>().data(), truth_shape->NumElements());
+
+ // Assume indices are sorted in row-major order.
+ std::vector<int64> sorted_order(truth_st_shape.dims());
+ std::iota(sorted_order.begin(), sorted_order.end(), 0);
+
+ sparse::SparseTensor hypothesis(*hypothesis_indices, *hypothesis_values,
+ hypothesis_st_shape, sorted_order);
+ sparse::SparseTensor truth(*truth_indices, *truth_values, truth_st_shape,
+ sorted_order);
+
+ // Group dims 0, 1, ..., RANK - 1. The very last dim is assumed
+ // to store the variable length sequences.
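+    // For example, if both inputs have dense shape [batch, max_length], then
+    // group_dims is {0} and the output below is a rank-1 tensor with one
+    // (optionally normalized) edit distance per batch entry.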
+ std::vector<int64> group_dims(truth_st_shape.dims() - 1);
+ std::iota(group_dims.begin(), group_dims.end(), 0);
+
+ TensorShape output_shape;
+ for (int d = 0; d < group_dims.size(); ++d) {
+ output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d),
+ truth_st_shape.dim_size(d)));
+ }
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", output_shape, &output));
+ auto output_t = output->flat<float>();
+ output_t.setZero();
+
+ std::vector<int64> output_strides(output_shape.dims());
+ output_strides[output_shape.dims() - 1] = 1;
+ for (int d = output_shape.dims() - 2; d >= 0; --d) {
+ output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
+ }
+
+ auto hypothesis_grouper = hypothesis.group(group_dims);
+ auto truth_grouper = truth.group(group_dims);
+
+ auto hypothesis_iter = hypothesis_grouper.begin();
+ auto truth_iter = truth_grouper.begin();
+
+ auto cmp = std::equal_to<T>();
+
+ while (hypothesis_iter != hypothesis_grouper.end() &&
+ truth_iter != truth_grouper.end()) {
+ sparse::Group truth_i = *truth_iter;
+ sparse::Group hypothesis_j = *hypothesis_iter;
+ std::vector<int64> g_truth = truth_i.group();
+ std::vector<int64> g_hypothesis = hypothesis_j.group();
+ auto truth_seq = truth_i.values<T>();
+ auto hypothesis_seq = hypothesis_j.values<T>();
+
+ if (g_truth == g_hypothesis) {
+ auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
+ output_strides.begin(), 0);
+ output_t(loc) =
+ gtl::LevenshteinDistance<T>(truth_seq, hypothesis_seq, cmp);
+ if (normalize_) output_t(loc) /= truth_seq.size();
+
+ ++hypothesis_iter;
+ ++truth_iter;
+ } else if (g_truth > g_hypothesis) { // missing truth @ this hypothesis
+ auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
+ output_strides.begin(), 0);
+ output_t(loc) = hypothesis_seq.size();
+ if (normalize_) output_t(loc) /= 0.0;
+ ++hypothesis_iter;
+ } else { // missing hypothesis @ this truth
+ auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
+ output_strides.begin(), 0);
+ output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
+ ++truth_iter;
+ }
+ }
+ while (hypothesis_iter != hypothesis_grouper.end()) { // missing truths
+ sparse::Group hypothesis_j = *hypothesis_iter;
+ std::vector<int64> g_hypothesis = hypothesis_j.group();
+ auto hypothesis_seq = hypothesis_j.values<T>();
+ auto loc = std::inner_product(g_hypothesis.begin(), g_hypothesis.end(),
+ output_strides.begin(), 0);
+ output_t(loc) = hypothesis_seq.size();
+ if (normalize_) output_t(loc) /= 0.0;
+ ++hypothesis_iter;
+ }
+ while (truth_iter != truth_grouper.end()) { // missing hypotheses
+ sparse::Group truth_i = *truth_iter;
+ std::vector<int64> g_truth = truth_i.group();
+ auto truth_seq = truth_i.values<T>();
+ auto loc = std::inner_product(g_truth.begin(), g_truth.end(),
+ output_strides.begin(), 0);
+ output_t(loc) = (normalize_) ? 1.0 : truth_seq.size();
+ ++truth_iter;
+ }
+ }
+
+ private:
+ bool normalize_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EditDistanceOp);
+};
+
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EditDistance").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ EditDistanceOp<T>);
+
+TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
+
+#undef REGISTER_CPU_KERNEL
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/encode_jpeg_op.cc b/tensorflow/core/kernels/encode_jpeg_op.cc
new file mode 100644
index 0000000000..8f5fd2f8be
--- /dev/null
+++ b/tensorflow/core/kernels/encode_jpeg_op.cc
@@ -0,0 +1,114 @@
+// See docs in ../ops/image_ops.cc
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+
+namespace tensorflow {
+
+// Encode an image to a JPEG stream
+class EncodeJpegOp : public OpKernel {
+ public:
+ explicit EncodeJpegOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("format", &format_));
+ if (format_.empty()) {
+ flags_.format = static_cast<jpeg::Format>(0);
+ } else if (format_ == "grayscale") {
+ flags_.format = jpeg::FORMAT_GRAYSCALE;
+ } else if (format_ == "rgb") {
+ flags_.format = jpeg::FORMAT_RGB;
+ } else {
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument(
+                      "format must be '', 'grayscale' or 'rgb', got ", format_));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("quality", &flags_.quality));
+ OP_REQUIRES(context, 0 <= flags_.quality && flags_.quality <= 100,
+ errors::InvalidArgument("quality must be in [0,100], got ",
+ flags_.quality));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("progressive", &flags_.progressive));
+ OP_REQUIRES_OK(
+ context, context->GetAttr("optimize_size", &flags_.optimize_jpeg_size));
+    OP_REQUIRES_OK(context, context->GetAttr("chroma_downsampling",
+                                              &flags_.chroma_downsampling));
+
+ string density_unit;
+ OP_REQUIRES_OK(context, context->GetAttr("density_unit", &density_unit));
+ if (density_unit == "in") {
+ flags_.density_unit = 1;
+ } else if (density_unit == "cm") {
+ flags_.density_unit = 2;
+ } else {
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument(
+                      "density_unit must be 'in' or 'cm', got ", density_unit));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("x_density", &flags_.x_density));
+ OP_REQUIRES_OK(context, context->GetAttr("y_density", &flags_.y_density));
+ OP_REQUIRES_OK(context, context->GetAttr("xmp_metadata", &xmp_metadata_));
+ flags_.xmp_metadata = xmp_metadata_; // StringPiece doesn't own data
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& image = context->input(0);
+ OP_REQUIRES(context, image.dims() == 3,
+ errors::InvalidArgument("image must be 3-dimensional",
+ image.shape().ShortDebugString()));
+
+ // Autodetect format if desired, otherwise make sure format and
+ // image channels are consistent.
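+    // For example, with format = "" a [height, width, 1] input is encoded as
+    // grayscale and a [height, width, 3] input as RGB; with an explicit
+    // format the channel count must match.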
+ int channels;
+ jpeg::CompressFlags adjusted_flags = flags_;
+ if (flags_.format == 0) {
+ channels = image.dim_size(2);
+ if (channels == 1) {
+ adjusted_flags.format = jpeg::FORMAT_GRAYSCALE;
+ } else if (channels == 3) {
+ adjusted_flags.format = jpeg::FORMAT_RGB;
+ } else {
+ OP_REQUIRES(context, false, errors::InvalidArgument(
+ "image must have 1 or 3 channels, got ",
+ image.shape().ShortDebugString()));
+ }
+ } else {
+ if (flags_.format == jpeg::FORMAT_GRAYSCALE) {
+ channels = 1;
+ } else { // RGB
+ channels = 3;
+ }
+ OP_REQUIRES(context, channels == image.dim_size(2),
+ errors::InvalidArgument("format ", format_, " expects ",
+ channels, " channels, got ",
+ image.shape().ShortDebugString()));
+ }
+
+ // Encode image to jpeg string
+    Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES(context,
+ jpeg::Compress(image.flat<uint8>().data(), image.dim_size(1),
+ image.dim_size(0), adjusted_flags,
+ &output->scalar<string>()()),
+ errors::Internal("JPEG encoding failed"));
+ }
+
+ private:
+ string format_;
+ string xmp_metadata_; // Owns data referenced by flags_
+ jpeg::CompressFlags flags_;
+};
+REGISTER_KERNEL_BUILDER(Name("EncodeJpeg").Device(DEVICE_CPU), EncodeJpegOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/encode_png_op.cc b/tensorflow/core/kernels/encode_png_op.cc
new file mode 100644
index 0000000000..5249074377
--- /dev/null
+++ b/tensorflow/core/kernels/encode_png_op.cc
@@ -0,0 +1,52 @@
+// See docs in ../ops/image_ops.cc
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/png/png_io.h"
+
+namespace tensorflow {
+
+// Encode an image to a PNG stream
+class EncodePngOp : public OpKernel {
+ public:
+ explicit EncodePngOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("compression", &compression_));
+ OP_REQUIRES(context, -1 <= compression_ && compression_ <= 9,
+ errors::InvalidArgument("compression should be in [-1,9], got ",
+ compression_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& image = context->input(0);
+    OP_REQUIRES(context, image.dims() == 3,
+                errors::InvalidArgument("image must be 3-dimensional: ",
+                                        image.shape().ShortDebugString()));
+ const int64 channels = image.dim_size(2);
+ OP_REQUIRES(context, channels == 1 || channels == 3 || channels == 4,
+ errors::InvalidArgument(
+ "image must have 1, 3, or 4 channels, got ", channels));
+
+ // Encode image to png string
+    Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES(context,
+ png::WriteImageToBuffer(
+ image.flat<uint8>().data(), image.dim_size(1),
+ image.dim_size(0), image.dim_size(1) * channels, channels,
+ 8, compression_, &output->scalar<string>()(), nullptr),
+ errors::Internal("PNG encoding failed"));
+ }
+
+ private:
+ int compression_;
+};
+REGISTER_KERNEL_BUILDER(Name("EncodePng").Device(DEVICE_CPU), EncodePngOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
new file mode 100644
index 0000000000..c217c18207
--- /dev/null
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -0,0 +1,444 @@
+// See docs in ../ops/parsing_ops.cc.
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status CheckValidType(const DataType& dtype) {
+ switch (dtype) {
+ case DT_INT64:
+ case DT_FLOAT:
+ case DT_STRING:
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Received input dtype: ",
+ DataTypeString(dtype));
+ }
+}
+
+Status CheckTypesMatch(const Feature& feature, const DataType& dtype,
+ bool* match) {
+ switch (dtype) {
+ case DT_INT64:
+ *match = (feature.kind_case() == Feature::kInt64List);
+ break;
+ case DT_FLOAT:
+ *match = (feature.kind_case() == Feature::kFloatList);
+ break;
+ case DT_STRING:
+ *match = (feature.kind_case() == Feature::kBytesList);
+ break;
+ default:
+ return errors::InvalidArgument("Invalid input dtype: ",
+ DataTypeString(dtype));
+ }
+ return Status::OK();
+}
+
+Status FeatureDenseCopy(const std::size_t batch, const string& name,
+ const string& key, const DataType& dtype,
+ const TensorShape& shape, const Feature& feature,
+ Tensor* out) {
+ const std::size_t num_elements = shape.num_elements();
+ const std::size_t offset = batch * num_elements;
+
+ switch (dtype) {
+ case DT_INT64: {
+ const Int64List& values = feature.int64_list();
+ if (static_cast<size_t>(values.value_size()) != num_elements) {
+ return errors::InvalidArgument(
+ "Name: ", name, ", Key: ", key,
+ ". Number of int64 values != expected. "
+ "values size: ",
+ values.value_size(), " but output shape: ",
+ shape.ShortDebugString());
+ }
+ auto out_p = out->flat<int64>().data() + offset;
+ std::copy_n(values.value().data(), num_elements, out_p);
+ return Status::OK();
+ }
+ case DT_FLOAT: {
+ const FloatList& values = feature.float_list();
+ if (static_cast<size_t>(values.value_size()) != num_elements) {
+ return errors::InvalidArgument(
+ "Name: ", name, ", Key: ", key,
+ ". Number of float values != expected. "
+ "values size: ",
+ values.value_size(), " but output shape: ",
+ shape.ShortDebugString());
+ }
+ auto out_p = out->flat<float>().data() + offset;
+ std::copy_n(values.value().data(), num_elements, out_p);
+ return Status::OK();
+ }
+ case DT_STRING: {
+ const BytesList& values = feature.bytes_list();
+ if (static_cast<size_t>(values.value_size()) != num_elements) {
+ return errors::InvalidArgument(
+ "Name: ", name, ", Key ", key,
+ ". number of bytes values != expected. "
+ "values size: ",
+ values.value_size(), " but output shape: ",
+ shape.ShortDebugString());
+ }
+ auto out_p = out->flat<string>().data() + offset;
+ std::transform(values.value().data(),
+ values.value().data() + num_elements, out_p,
+ [](const string* s) { return *s; });
+ return Status::OK();
+ }
+ default:
+ return errors::InvalidArgument("Invalid input dtype: ",
+ DataTypeString(dtype));
+ }
+}
+
+Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
+ const DataType& dtype, const Feature& feature) {
+ switch (dtype) {
+ case DT_INT64: {
+ const Int64List& values = feature.int64_list();
+ const int64 num_elements = values.value_size();
+ Tensor out(dtype, TensorShape({num_elements}));
+ auto out_p = out.flat<int64>().data();
+ std::copy_n(values.value().data(), num_elements, out_p);
+ return out;
+ }
+ case DT_FLOAT: {
+ const FloatList& values = feature.float_list();
+ const int64 num_elements = values.value_size();
+ Tensor out(dtype, TensorShape({num_elements}));
+ auto out_p = out.flat<float>().data();
+ std::copy_n(values.value().data(), num_elements, out_p);
+ return out;
+ }
+ case DT_STRING: {
+ const BytesList& values = feature.bytes_list();
+ const int64 num_elements = values.value_size();
+ Tensor out(dtype, TensorShape({num_elements}));
+ auto out_p = out.flat<string>().data();
+ std::transform(values.value().data(),
+ values.value().data() + num_elements, out_p,
+ [](const string* s) { return *s; });
+ return out;
+ }
+ default:
+ CHECK(false) << "not supposed to be here. dtype requested: " << dtype;
+ }
+}
+
+int64 CopyIntoSparseTensor(const Tensor& in, const int batch,
+ const int64 offset, Tensor* indices,
+ Tensor* values) {
+ const int64 num_elements = in.shape().num_elements();
+ const DataType& dtype = in.dtype();
+ CHECK_EQ(dtype, values->dtype());
+
+ // Update indices
+ auto ix_t = indices->matrix<int64>();
+ int64* ix_p = &ix_t(offset, 0);
+ for (int64 i = 0; i < num_elements; ++i, ix_p += 2) {
+ *ix_p = batch; // Column 0 stores the batch entry
+ *(ix_p + 1) = i; // Column 1 stores the index in the batch
+ }
+
+ // Copy values over
+ switch (dtype) {
+ case DT_INT64: {
+ std::copy_n(in.flat<int64>().data(), num_elements,
+ values->flat<int64>().data() + offset);
+ break;
+ }
+ case DT_FLOAT: {
+ std::copy_n(in.flat<float>().data(), num_elements,
+ values->flat<float>().data() + offset);
+ break;
+ }
+ case DT_STRING: {
+ std::copy_n(in.flat<string>().data(), num_elements,
+ values->flat<string>().data() + offset);
+ break;
+ }
+ default:
+ CHECK(false) << "Not supposed to be here. Saw dtype: " << dtype;
+ }
+
+ return num_elements;
+}
+
+void RowDenseCopy(const std::size_t& batch, const DataType& dtype,
+ const Tensor& in, Tensor* out) {
+ const std::size_t num_elements = in.shape().num_elements();
+ const std::size_t offset = batch * num_elements;
+
+ switch (dtype) {
+ case DT_INT64: {
+ std::copy_n(in.flat<int64>().data(), num_elements,
+ out->flat<int64>().data() + offset);
+ break;
+ }
+ case DT_FLOAT: {
+ std::copy_n(in.flat<float>().data(), num_elements,
+ out->flat<float>().data() + offset);
+ break;
+ }
+ case DT_STRING: {
+ std::copy_n(in.flat<string>().data(), num_elements,
+ out->flat<string>().data() + offset);
+ break;
+ }
+ default:
+ CHECK(false) << "Not supposed to be here. Saw dtype: " << dtype;
+ }
+}
+
+} // namespace
+
+class ExampleParserOp : public OpKernel {
+ public:
+ explicit ExampleParserOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Ndense", &num_dense_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Nsparse", &num_sparse_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tdense", &dense_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
+
+    OP_REQUIRES(
+        ctx, static_cast<size_t>(num_sparse_) == sparse_types_.size(),
+        errors::InvalidArgument("len(sparse_keys) != len(sparse_types)"));
+    OP_REQUIRES(ctx, static_cast<size_t>(num_dense_) == dense_types_.size(),
+                errors::InvalidArgument("len(dense_keys) != len(dense_types)"));
+    OP_REQUIRES(ctx, static_cast<size_t>(num_dense_) == dense_shapes_.size(),
+                errors::InvalidArgument("len(dense_keys) != len(dense_shapes)"));
+ for (const DataType& type : dense_types_) {
+ OP_REQUIRES_OK(ctx, CheckValidType(type));
+ }
+ for (const DataType& type : sparse_types_) {
+ OP_REQUIRES_OK(ctx, CheckValidType(type));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* names;
+ const Tensor* serialized;
+ OpInputList dense_keys;
+ OpInputList sparse_keys;
+ OpInputList dense_defaults;
+
+ OP_REQUIRES_OK(ctx, ctx->input("names", &names));
+ OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
+ OP_REQUIRES_OK(ctx, ctx->input_list("dense_keys", &dense_keys));
+ OP_REQUIRES_OK(ctx, ctx->input_list("sparse_keys", &sparse_keys));
+ OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
+
+ std::vector<string> dense_keys_t(num_dense_);
+ std::vector<string> sparse_keys_t(num_sparse_);
+ CHECK_EQ(dense_keys.size(), num_dense_);
+ CHECK_EQ(sparse_keys.size(), num_sparse_);
+ for (int di = 0; di < num_dense_; ++di) {
+ dense_keys_t[di] = dense_keys[di].scalar<string>()();
+ }
+ for (int di = 0; di < num_sparse_; ++di) {
+ sparse_keys_t[di] = sparse_keys[di].scalar<string>()();
+ }
+
+ bool has_names = (names->NumElements() > 0);
+ if (has_names) {
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(names->shape()),
+ errors::InvalidArgument("Expected names to be a vector, got shape: ",
+ names->shape().ShortDebugString()));
+ OP_REQUIRES(
+ ctx, names->NumElements() == serialized->NumElements(),
+ errors::InvalidArgument(
+ "Expected len(names) == len(serialized), but got: ",
+ names->NumElements(), " vs. ", serialized->NumElements()));
+ }
+ auto names_t = names->flat<string>();
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
+ errors::InvalidArgument(
+ "Expected serialized to be a vector, got shape: ",
+ serialized->shape().ShortDebugString()));
+ OP_REQUIRES(ctx, dense_defaults.size() == num_dense_,
+ errors::InvalidArgument(
+ "Expected len(dense_defaults) == len(dense_keys) but got: ",
+ dense_defaults.size(), " vs. ", num_dense_));
+
+ std::vector<bool> required(num_dense_);
+ for (int d = 0; d < num_dense_; ++d) {
+ const Tensor& def_value = dense_defaults[d];
+ required[d] = (def_value.NumElements() == 0); // No default provided.
+
+ if (def_value.NumElements() > 0) {
+ OP_REQUIRES(
+ ctx, def_value.shape() == dense_shapes_[d],
+ errors::InvalidArgument("def_value[", d, "].shape() == ",
+ def_value.shape().ShortDebugString(),
+ " != dense_shapes_[", d, "] == ",
+ dense_shapes_[d].ShortDebugString()));
+ OP_REQUIRES(ctx, def_value.dtype() == dense_types_[d],
+ errors::InvalidArgument(
+ "dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != dense_types_[",
+ d, "] == ", DataTypeString(dense_types_[d])));
+ }
+ }
+
+ auto serialized_t = serialized->vec<string>();
+
+ const int batch_size = serialized_t.size();
+
+ OpOutputList sparse_indices;
+ OpOutputList sparse_values;
+ OpOutputList sparse_shapes;
+ OpOutputList dense_values;
+
+ OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
+ OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
+
+ // Preallocate dense_values, since we know their sizes
+ for (int d = 0; d < num_dense_; ++d) {
+ TensorShape out_shape;
+ out_shape.AddDim(batch_size);
+ for (const int dim : dense_shapes_[d].dim_sizes()) out_shape.AddDim(dim);
+ Tensor* out = nullptr;
+ dense_values.allocate(d, out_shape, &out);
+ }
+
+ // sparse_values_tmp will be num_sparse_ x batch_size, containing
+    // the sparse values from the input Examples. After these are all
+ // stored, we can allocate properly sized outputs and copy data over.
+ // Doing it this way saves us the trouble of either performing
+ // deserialization twice, or alternatively storing all copies of
+ // the full Example protos.
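+    // For example, a single sparse feature over a batch of two Examples with
+    // 3 and 1 values respectively ends up with total_num_features = 4
+    // (indices of shape [4, 2], values of shape [4]) and max_num_features = 3,
+    // so the reported dense shape is [batch_size, 3].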
+ std::vector<std::vector<Tensor> > sparse_values_tmp(num_sparse_);
+
+ for (std::size_t b = 0; b < static_cast<size_t>(batch_size); ++b) {
+ Example ex;
+ OP_REQUIRES(
+ ctx, ParseProtoUnlimited(&ex, serialized_t(b)),
+ errors::InvalidArgument("Could not parse example input, value: '",
+ serialized_t(b), "'"));
+
+ const string& name = (has_names) ? names_t(b) : "<unknown>";
+ const Features& features = ex.features();
+ const auto& feature_dict = features.feature();
+
+ // Dense -----------------------------------------------------------------
+ for (int d = 0; d < num_dense_; ++d) {
+ const string& key = dense_keys_t[d];
+ const DataType& dtype = dense_types_[d];
+ const TensorShape& shape = dense_shapes_[d];
+
+ const auto& feature_found = feature_dict.find(key);
+ OP_REQUIRES(
+ ctx, (feature_found != feature_dict.end()) || !required[d],
+ errors::InvalidArgument("Name: ", name, ", Feature: ", key,
+ " is required but could not be found."));
+ if (feature_found != feature_dict.end()) {
+ const Feature& f = feature_found->second;
+ bool types_match;
+ OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
+ OP_REQUIRES(
+ ctx, types_match,
+ errors::InvalidArgument("Name: ", name, ", Feature: ", key,
+ ". Data types don't match. ",
+ "Expected type: ", DataTypeString(dtype),
+ " Feature is: ", f.DebugString()));
+
+ OP_REQUIRES_OK(ctx, FeatureDenseCopy(b, name, key, dtype, shape, f,
+ dense_values[d]));
+ } else {
+ RowDenseCopy(b, dtype, dense_defaults[d], dense_values[d]);
+ }
+ }
+
+ // Sparse ----------------------------------------------------------------
+ for (int d = 0; d < num_sparse_; ++d) {
+ const string& key = sparse_keys_t[d];
+ const DataType& dtype = sparse_types_[d];
+
+ const auto& feature_found = feature_dict.find(key);
+ bool feature_has_data = // Found key & data type is set
+ (feature_found != feature_dict.end() &&
+ (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
+ if (feature_has_data) {
+ const Feature& f = feature_found->second;
+ bool types_match;
+ OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
+ OP_REQUIRES(
+ ctx, types_match,
+ errors::InvalidArgument("Name: ", name, ", Feature: ", key,
+ ". Data types don't match. ",
+ "Expected type: ", DataTypeString(dtype),
+ " Feature is: ", f.DebugString()));
+ sparse_values_tmp[d].push_back(FeatureSparseCopy(b, key, dtype, f));
+ } else {
+ sparse_values_tmp[d].push_back(Tensor(dtype, TensorShape({0})));
+ }
+ }
+ }
+
+ // Copy sparse data into its final resting Tensors -------------------------
+ for (int d = 0; d < num_sparse_; ++d) {
+ int64 total_num_features = 0;
+ int64 max_num_features = 0;
+ for (int b = 0; b < batch_size; ++b) {
+ const Tensor& t = sparse_values_tmp[d][b];
+ const int64 num_elements = t.shape().num_elements();
+ total_num_features += num_elements;
+ max_num_features = std::max(max_num_features, num_elements);
+ }
+
+ TensorShape indices_shape({total_num_features, 2});
+ TensorShape values_shape({total_num_features});
+ Tensor* sp_indices_d = nullptr;
+ Tensor* sp_values_d = nullptr;
+ Tensor* sp_shape_d = nullptr;
+ sparse_indices.allocate(d, indices_shape, &sp_indices_d);
+ sparse_values.allocate(d, values_shape, &sp_values_d);
+ sparse_shapes.allocate(d, TensorShape({2}), &sp_shape_d);
+
+ auto shape_t = sp_shape_d->vec<int64>();
+ shape_t(0) = batch_size;
+ shape_t(1) = max_num_features;
+
+ int64 offset = 0;
+
+ for (int b = 0; b < batch_size; ++b) {
+ const int64 num_elements = CopyIntoSparseTensor(
+ sparse_values_tmp[d][b], b, offset, sp_indices_d, sp_values_d);
+ offset += num_elements;
+ }
+ }
+ }
+
+ protected:
+ int64 num_sparse_;
+ int64 num_dense_;
+ std::vector<DataType> sparse_types_;
+ std::vector<DataType> dense_types_;
+ std::vector<TensorShape> dense_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU),
+ ExampleParserOp);
+
+} // namespace tensorflow
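
The loop above flattens each sparse feature into three aligned outputs: an indices matrix of (batch, position) pairs, a values vector, and a dense shape of [batch_size, max_num_features]. A minimal standalone sketch of that packing, using plain C++ containers and made-up per-row values in place of Tensors:

// Sketch of the sparse packing done above; rows holds what sparse_values_tmp
// would hold for one sparse feature (the values are hypothetical).
#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<float>> rows = {{1.f, 2.f}, {}, {3.f, 4.f, 5.f}};

  int64_t total = 0, max_per_row = 0;
  for (const auto& r : rows) {
    total += static_cast<int64_t>(r.size());
    max_per_row = std::max<int64_t>(max_per_row, r.size());
  }

  std::vector<std::array<int64_t, 2>> indices;  // total x 2: (batch, position)
  std::vector<float> values;                    // length total
  const int64_t shape[2] = {static_cast<int64_t>(rows.size()), max_per_row};

  for (int64_t b = 0; b < static_cast<int64_t>(rows.size()); ++b) {
    for (int64_t i = 0; i < static_cast<int64_t>(rows[b].size()); ++i) {
      indices.push_back(std::array<int64_t, 2>{b, i});
      values.push_back(rows[b][i]);
    }
  }

  std::cout << "sparse_shape = [" << shape[0] << ", " << shape[1] << "]\n";
  for (size_t k = 0; k < values.size(); ++k) {
    std::cout << "(" << indices[k][0] << ", " << indices[k][1] << ") -> "
              << values[k] << "\n";
  }
  std::cout << "total_num_features = " << total << "\n";
}

For the three rows above this prints a dense shape of [3, 3], five (index, value) pairs, and total_num_features = 5, the same quantities the kernel uses to size sparse_indices and sparse_values.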
diff --git a/tensorflow/core/kernels/fact_op.cc b/tensorflow/core/kernels/fact_op.cc
new file mode 100644
index 0000000000..dfe220fffb
--- /dev/null
+++ b/tensorflow/core/kernels/fact_op.cc
@@ -0,0 +1,96 @@
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+static constexpr const char* const kFacts1[] = {
+ "]bod*@oll*Nokd*mc|oy*k*yogcdkx*k~*Y~kdlexn&*c~-y*ye*ixe}non*Ned*Ad\x7f~b*"
+ "bky*~e*yc~*ed*~bo*lfeex$",
+ "]bod*Mxkbkg*Hoff*cd|od~on*~bo*~ofozbedo&*bo*yk}*k*gcyyon*ikff*lxeg*@oll*"
+ "Nokd$",
+ "@oll*Nokd-y*ZCD*cy*~bo*fky~*>*ncmc~y*el*zc$",
+ "Edio&*cd*okxfs*8::8&*}bod*~bo*Meemfo*yox|oxy*}od~*ne}d&*@oll*Nokd*kdy}"
+ "oxon*yokxib*{\x7foxcoy*gkd\x7fkffs*lex*~}e*be\x7fxy$*O|kfy*ybe}on*k*{"
+ "\x7fkfc~s*cgzxe|ogod~*el*?*zecd~y$",
+ "@oll*Nokd*z\x7f~y*bcy*zkd~y*ed*edo*fom*k~*k*~cgo&*h\x7f~*cl*bo*bkn*gexo*~"
+ "bkd*~}e*fomy&*se\x7f*}e\x7f\x66n*yoo*~bk~*bcy*kzzxekib*cy*ki~\x7fkffs*"
+ "E\"fem*d#$",
+ "@oll*Nokd*iegzcfoy*kdn*x\x7f\x64y*bcy*ieno*holexo*y\x7fhgc~~cdm&*h\x7f~*"
+ "edfs*~e*iboia*lex*iegzcfox*h\x7fmy$",
+ "@oll*Nokd*ixok~on*~bo*}exfn-y*lcxy~*E\";%d#*kfmexc~bg$",
+ "@oll*Nokd*}xe~o*kd*E\"dT8#*kfmexc~bg*edio$*C~*}ky*lex*~bo*^xk|ofcdm*"
+ "Ykfoygkd*Zxehfog$",
+ "^bo*xk~o*k~*}bcib*@oll*Nokd*zxen\x7fioy*ieno*`\x7fgzon*hs*k*lki~ex*el*>:*"
+ "cd*fk~o*8:::*}bod*bo*\x7fzmxknon*bcy*aoshekxn*~e*_YH8$:$",
+ "@oll*Nokd*ikd*hok~*se\x7f*k~*ieddoi~*le\x7fx$*Cd*~bxoo*ge|oy$",
+ "@oll*Nokd*ade}y*}bs*~bo*kdy}ox*cy*>8$",
+ "@oll*Nokd*y~kx~y*bcy*zxemxkggcdm*yoyycedy*}c~b*(ik~*4*%no|%gog($",
+ "]bod*@oll*Nokd*yksy*(ezod*~bo*zen*hks*neexy(&*Bkf*ezody*~bo*zen*hks*"
+ "neexy$",
+ "@oll*Nokd*ycgzfs*}kfay*cd~e*Gexnex$",
+ "Ib\x7fia*Dexxcy*cy*@oll*Nokd-y*8:/*zxe`oi~$",
+ "@oll*Nokd-y*}k~ib*ncyzfksy*yoiedny*ycdio*@kd\x7fkxs*;y~&*;3=:$*Bo*cy*do|"
+ "ox*fk~o$",
+ "]bod*se\x7fx*ieno*bky*\x7f\x64nolcdon*hobk|cex&*se\x7f*mo~*k*"
+ "yomlk\x7f\x66~*kdn*iexx\x7fz~on*nk~k$*]bod*@oll*Nokd-y*ieno*bky*"
+ "\x7f\x64nolcdon*hobk|cex&*k*\x7f\x64\x63iexd*xcnoy*cd*ed*k*xkcdhe}*kdn*mc|"
+ "oy*o|oxshens*lxoo*cio*ixokg$",
+ "Moell*Bcd~ed*neoyd-~*doon*~e*gkao*bcnnod*\x7f\x64\x63~y$*^bos*bcno*hs*~"
+ "bogyof|oy*}bod*bo*kzzxekiboy$",
+ "Moell*Bcd~ed*neoyd-~*ncykmxoo&*bo*ied~xky~c|ofs*nc|oxmoy$",
+ "Nooz*Hofcol*Do~}exay*ki~\x7fkffs*hofco|o*noozfs*cd*Moell*Bcd~ed$",
+ "Moell*Bcd~ed*bky*ncyie|oxon*be}*~bo*hxkcd*xokffs*}exay$$$*edio*k*sokx&*"
+ "lex*~bo*fky~*8?*sokxy$",
+ "Gkxae|*xkdneg*lcofny*~bcda*Moell*Bcd~ed*cy*cd~xki~khfo$",
+ "Moell*Bcd~ed*ncnd-~*cd|od~*femci&*h\x7f~*bcy*mxok~'mxok~'mxkdnlk~box*ncn$*"
+ "\"^x\x7fo+#",
+ "Moell*Bcd~ed*bky*}xc~~od*~}e*zkzoxy*~bk~*kxo*noy~cdon*~e*xo|ef\x7f~cedcpo*"
+ "gkibcdo*fokxdcdm$*Dehens*ade}y*}bcib*~}e$"};
+static constexpr uint64 kNum1 = sizeof(kFacts1) / sizeof(kFacts1[0]);
+
+static constexpr const char* const kFacts2[] = {
+ "Yoxmos*Hxcd*kdn*Hk~gkd*bk|o*do|ox*hood*yood*k~*~bo*ykgo*zfkio*k~*~bo*ykgo*"
+ "~cgo$"};
+static constexpr uint64 kNum2 = sizeof(kFacts2) / sizeof(kFacts2[0]);
+
+static void E(string* s) {
+ for (size_t j = 0; j < s->size(); ++j) {
+ (*s)[j] ^= '\n';
+ }
+}
+
+template <const char* const FACTS[], uint64 N>
+class FactOpKernel : public OpKernel {
+ public:
+ explicit FactOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, TensorShape({}), &output_tensor));
+ auto output = output_tensor->template scalar<string>();
+
+ string coded = FACTS[context->env()->NowMicros() % N];
+ E(&coded);
+ output() = coded;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Fact").Device(DEVICE_GPU).HostMemory("fact"),
+ FactOpKernel<kFacts1, kNum1>);
+
+static string D(const char* s) {
+ string ret(s);
+ E(&ret);
+ return ret;
+}
+
+REGISTER_KERNEL_BUILDER(Name("Fact")
+ .Device(DEVICE_CPU)
+ .Label(D("Yoxmos").c_str()),
+ FactOpKernel<kFacts2, kNum2>);
+REGISTER_KERNEL_BUILDER(Name("Fact")
+ .Device(DEVICE_CPU)
+ .Label(D("yoxmos").c_str()),
+ FactOpKernel<kFacts2, kNum2>);
+
+} // namespace tensorflow
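
The string tables above are stored XOR-ed byte-by-byte with '\n', and E() is its own inverse, so the same routine both encodes and decodes. A small sketch of that round trip with a hypothetical sample string (nothing here comes from the real tables):

// Sketch of the XOR obfuscation used by E() above: applying it twice returns
// the original string, because x ^ k ^ k == x for any byte.
#include <cassert>
#include <iostream>
#include <string>

static void E(std::string* s) {
  for (size_t j = 0; j < s->size(); ++j) (*s)[j] ^= '\n';
}

int main() {
  std::string original = "hello, world";  // hypothetical sample text
  std::string coded = original;
  E(&coded);                 // obfuscate
  assert(coded != original);
  E(&coded);                 // a second application restores the input
  assert(coded == original);
  std::cout << coded << "\n";
}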
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
new file mode 100644
index 0000000000..20e1f31f06
--- /dev/null
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -0,0 +1,518 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+FIFOQueue::FIFOQueue(int capacity, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name)
+ : QueueBase(component_dtypes, component_shapes, name),
+ capacity_(capacity),
+ closed_(false) {}
+
+Status FIFOQueue::Initialize() {
+ if (component_dtypes_.empty()) {
+ return errors::InvalidArgument("Empty component types for queue ", name_);
+ }
+ if (!component_shapes_.empty() &&
+ component_dtypes_.size() != component_shapes_.size()) {
+ return errors::InvalidArgument("Different number of component types (",
+ component_dtypes_.size(), ") vs. shapes (",
+ component_shapes_.size(), ").");
+ }
+
+ mutex_lock lock(mu_);
+ queues_.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ queues_.push_back(SubQueue());
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status FIFOQueue::ValidateTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ component_shapes_[i].ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status FIFOQueue::ValidateManyTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ // Expected shape is [batch_size] + component_shapes_[i]
+ const TensorShape expected_shape = ManyOutShape(i, batch_size);
+ if (!tuple[i].shape().IsSameSize(expected_shape)) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ expected_shape.ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ } else {
+ for (size_t i = 1; i < tuple.size(); ++i) {
+ if (tuple[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All input tensors must have the same size in the 0th ",
+ "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+ ", and should have ", batch_size);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
+ DCHECK_GT(queues_[0].size(), 0);
+ (*tuple).reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ (*tuple).push_back(*queues_[i][0].AccessTensor(ctx));
+ queues_[i].pop_front();
+ }
+}
+
+void FIFOQueue::Cancel(Action action, CancellationToken token) {
+ DoneCallback callback = nullptr;
+ {
+ mutex_lock lock(mu_);
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ for (Attempt& attempt : *attempts) {
+ if (attempt.cancellation_token == token) {
+ attempt.is_cancelled = true;
+ if (action == kEnqueue) {
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ } else {
+ attempt.context->SetStatus(
+ errors::Cancelled("Dequeue operation was cancelled"));
+ }
+ std::swap(callback, attempt.done_callback);
+ break;
+ }
+ }
+ }
+ if (callback) {
+ callback();
+ FlushUnlocked();
+ }
+}
+
+void FIFOQueue::CloseAndCancel() {
+ std::vector<DoneCallback> callbacks;
+ {
+ mutex_lock lock(mu_);
+ closed_ = true;
+ for (Attempt& attempt : enqueue_attempts_) {
+ attempt.is_cancelled = true;
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ callbacks.emplace_back(std::move(attempt.done_callback));
+ }
+ }
+ for (const DoneCallback& callback : callbacks) {
+ callback();
+ }
+ FlushUnlocked();
+}
+
+bool FIFOQueue::TryAttemptLocked(Action action,
+ std::vector<CleanUp>* clean_up) {
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ bool progress = false;
+ bool done = false;
+ while (!done && !attempts->empty()) {
+ if (attempts->front().is_cancelled) {
+ if (action == kEnqueue) {
+ LOG(INFO) << "Skipping cancelled enqueue attempt";
+ } else {
+ LOG(INFO) << "Skipping cancelled dequeue attempt";
+ }
+ attempts->pop_front();
+ } else {
+ Attempt* cur_attempt = &attempts->front();
+ switch (cur_attempt->run_callback(cur_attempt)) {
+ case kNoProgress:
+ done = true;
+ break;
+ case kProgress:
+ done = true;
+ progress = true;
+ break;
+ case kComplete:
+ progress = true;
+ clean_up->emplace_back(std::move(cur_attempt->done_callback),
+ cur_attempt->cancellation_token,
+ cur_attempt->context->cancellation_manager());
+ attempts->pop_front();
+ break;
+ }
+ }
+ }
+ return progress;
+}
+
+void FIFOQueue::FlushUnlocked() {
+ std::vector<CleanUp> clean_up;
+ Ref();
+ {
+ mutex_lock lock(mu_);
+ bool changed;
+ do {
+ changed = TryAttemptLocked(kEnqueue, &clean_up);
+ changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+ } while (changed);
+ }
+ Unref();
+ for (const auto& to_clean : clean_up) {
+ if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+ // NOTE(mrry): We can safely ignore the return value of
+ // DeregisterCallback because the mutex mu_ ensures that the
+ // cleanup action only executes once.
+ to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+ }
+ to_clean.finished();
+ }
+}
+
+void FIFOQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) {
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kEnqueue, token); });
+ if (!already_cancelled) {
+ enqueue_attempts_.emplace_back(
+ 1, callback, ctx, token,
+ [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(
+ errors::Aborted("FIFOQueue '", name_, "' is closed."));
+ return kComplete;
+ }
+ if (queues_[0].size() < static_cast<size_t>(capacity_)) {
+ for (int i = 0; i < num_components(); ++i) {
+ queues_[i].push_back(PersistentTensor(tuple[i]));
+ }
+ return kComplete;
+ } else {
+ return kNoProgress;
+ }
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
+ callback();
+ }
+}
+
+/* static */
+Status FIFOQueue::GetElementComponentFromBatch(const FIFOQueue::Tuple& tuple,
+ int index, int component,
+ OpKernelContext* ctx,
+ PersistentTensor* out_tensor) {
+ TensorShape element_shape(tuple[component].shape());
+ element_shape.RemoveDim(0);
+ Tensor* element_access = nullptr;
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ tuple[component].dtype(), element_shape, out_tensor, &element_access));
+ TF_RETURN_IF_ERROR(
+ CopySliceToElement(tuple[component], element_access, index));
+ return Status::OK();
+}
+
+void FIFOQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) {
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (batch_size == 0) {
+ callback();
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kEnqueue, token); });
+ if (!already_cancelled) {
+ enqueue_attempts_.emplace_back(
+ batch_size, callback, ctx, token,
+ [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(
+ errors::Aborted("FIFOQueue '", name_, "' is closed."));
+ return kComplete;
+ }
+ RunResult result = kNoProgress;
+ while (queues_[0].size() < static_cast<size_t>(capacity_)) {
+ result = kProgress;
+ const int index =
+ tuple[0].dim_size(0) - attempt->elements_requested;
+ for (int i = 0; i < num_components(); ++i) {
+ PersistentTensor element;
+ attempt->context->SetStatus(GetElementComponentFromBatch(
+ tuple, index, i, attempt->context, &element));
+ if (!attempt->context->status().ok()) return kComplete;
+ queues_[i].push_back(element);
+ }
+ --attempt->elements_requested;
+ if (attempt->elements_requested == 0) {
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
+ callback();
+ }
+}
+
+void FIFOQueue::TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) {
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ 1, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const int32 s = queues_[0].size();
+ if (closed_ && s == 0) {
+ attempt->context->SetStatus(errors::OutOfRange(
+ "FIFOQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ", 1, ", current size ", s,
+ ")"));
+ return kComplete;
+ }
+ if (s > 0) {
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
+ attempt->done_callback = [callback, tuple]() { callback(tuple); };
+ return kComplete;
+ } else {
+ return kNoProgress;
+ }
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
+
+void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) {
+ if (!specified_shapes()) {
+ ctx->SetStatus(
+ errors::InvalidArgument("FIFOQueue's DequeueMany requires the "
+ "components to have specified shapes."));
+ callback(Tuple());
+ return;
+ }
+ if (num_elements == 0) {
+ Tuple tuple;
+ tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ // TODO(josh11b,misard): Switch to allocate_output(). Problem is
+ // this breaks the abstraction boundary since we don't *really*
+ // know if and how the Tensors in the tuple we pass to callback
+ // correspond to the outputs of *ctx. For example, the
+ // ReaderRead Op uses TryDequeue() to get a filename out of a
+ // queue that is used internally by the reader and is not
+ // associated with any output of the ReaderRead.
+ // mrry@ adds:
+ // Maybe we need to pass a std::function<Tensor*(...)> (or
+ // better signature) that calls the appropriate allocator
+ // function in addition to ctx? (Or support a shim Allocator
+ // that has an internal OpKernelContext*, and dispatches to the
+ // appropriate method?)
+ // misard@ adds:
+ // I don't see that a std::function would help. The problem is
+ // that at this point (allocation time) the system doesn't know
+ // what is going to happen to the element read out of the
+ // queue. As long as we keep the generality that TensorFlow Ops
+ // do their own dynamic allocation in arbitrary C++ code, we
+ // need to preserve robustness to allocating output Tensors with
+ // the 'wrong' attributes, and fixing up with a copy. The only
+ // improvement I can see here in the future would be to support
+ // an optimized case where the queue 'knows' what attributes to
+ // use, and plumbs them through here.
+ Tensor element;
+ ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element);
+ tuple.emplace_back(element);
+ }
+ callback(tuple);
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ num_elements, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int32 s = queues_[0].size();
+ if (closed_ && s < attempt->elements_requested) {
+ attempt->context->SetStatus(errors::OutOfRange(
+ "FIFOQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ",
+ attempt->elements_requested, ", current size ", s, ")"));
+
+ // TODO(mrry): Add support for producing a partial batch as
+ // output when the queue is closed.
+ if (!attempt->tuple.empty()) {
+ // Restore already-dequeued elements to the front of the queue.
+ for (int64 i = attempt->tuple[0].dim_size(0) -
+ attempt->elements_requested - 1;
+ i >= 0; --i) {
+ for (int j = 0; j < num_components(); ++j) {
+ PersistentTensor element;
+ Status s = GetElementComponentFromBatch(
+ attempt->tuple, i, j, attempt->context, &element);
+ if (!s.ok()) {
+ attempt->context->SetStatus(
+ errors::DataLoss("Failed to restore element from "
+ "partially-dequeued batch "
+ "to FIFOQueue"));
+ }
+ queues_[j].push_front(element);
+ }
+ }
+ }
+ return kComplete;
+ }
+
+ RunResult result = kNoProgress;
+ for (; s > 0; --s) {
+ if (attempt->tuple.empty()) {
+ // Only allocate tuple when we have something to dequeue
+ // so we don't use excessive memory when there are many
+ // blocked dequeue attempts waiting.
+ attempt->tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ const TensorShape shape =
+ ManyOutShape(i, attempt->elements_requested);
+ Tensor element;
+ attempt->context->allocate_temp(component_dtypes_[i], shape,
+ &element);
+ attempt->tuple.emplace_back(element);
+ }
+ }
+ result = kProgress;
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
+ const int index =
+ attempt->tuple[0].dim_size(0) - attempt->elements_requested;
+ for (int i = 0; i < num_components(); ++i) {
+ attempt->context->SetStatus(
+ CopyElementToSlice(tuple[i], &attempt->tuple[i], index));
+ if (!attempt->context->status().ok()) return kComplete;
+ }
+ tuple.clear();
+ --attempt->elements_requested;
+ if (attempt->elements_requested == 0) {
+ tuple = attempt->tuple;
+ attempt->done_callback = [callback, tuple]() {
+ callback(tuple);
+ };
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
+
+void FIFOQueue::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) {
+ if (cancel_pending_enqueues) {
+ CloseAndCancel();
+ callback();
+ } else {
+ {
+ mutex_lock lock(mu_);
+ enqueue_attempts_.emplace_back(
+ 0, callback, ctx, CancellationManager::kInvalidToken,
+ [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "FIFOQueue '", name_, "' is already closed."));
+ } else {
+ closed_ = true;
+ }
+ return kComplete;
+ });
+ }
+ FlushUnlocked();
+ }
+}
+
+Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
+ TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue"));
+ TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
+ TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
+ TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));
+ return Status::OK();
+}
+
+} // namespace tensorflow
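
TryAttemptLocked above drives a small state machine over the pending attempts: kNoProgress and kProgress both end the sweep (the attempt stays queued), while kComplete pops the attempt and defers its done callback until the lock is released. A simplified, self-contained sketch of that control flow with a toy Attempt type, not the real FIFOQueue/QueueBase API:

#include <deque>
#include <functional>
#include <iostream>
#include <vector>

enum RunResult { kNoProgress, kProgress, kComplete };

struct Attempt {
  bool is_cancelled;
  std::function<RunResult(Attempt*)> run_callback;  // runs "under the lock"
  std::function<void()> done_callback;              // runs after the lock
};

// Returns true if any attempt made progress. Completed callbacks are collected
// in *clean_up so the caller can run them once the (hypothetical) lock is gone.
bool TryAttempt(std::deque<Attempt>* attempts,
                std::vector<std::function<void()>>* clean_up) {
  bool progress = false;
  bool done = false;
  while (!done && !attempts->empty()) {
    if (attempts->front().is_cancelled) {
      attempts->pop_front();
      continue;
    }
    Attempt* cur = &attempts->front();
    switch (cur->run_callback(cur)) {
      case kNoProgress:  // blocked: stop sweeping, keep the attempt queued
        done = true;
        break;
      case kProgress:    // partial work done, but the attempt stays queued
        done = true;
        progress = true;
        break;
      case kComplete:    // finished: pop it and defer its done callback
        progress = true;
        clean_up->push_back(std::move(cur->done_callback));
        attempts->pop_front();
        break;
    }
  }
  return progress;
}

int main() {
  std::deque<Attempt> attempts;
  attempts.push_back(Attempt{false, [](Attempt*) { return kComplete; },
                             [] { std::cout << "first attempt finished\n"; }});
  attempts.push_back(Attempt{false, [](Attempt*) { return kNoProgress; },
                             [] { std::cout << "still waiting\n"; }});

  std::vector<std::function<void()>> clean_up;
  TryAttempt(&attempts, &clean_up);
  for (auto& done : clean_up) done();  // run outside the lock, as above
}

Only the first attempt completes here; the second stays queued, mirroring how a blocked dequeue waits for a later FlushUnlocked pass.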
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
new file mode 100644
index 0000000000..e9fe5f34a4
--- /dev/null
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -0,0 +1,127 @@
+#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class FIFOQueue : public QueueBase {
+ public:
+ FIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name);
+ Status Initialize(); // Must be called before any other method.
+
+ // Implementations of QueueInterface methods --------------------------------
+
+ Status ValidateTuple(const Tuple& tuple) override;
+ Status ValidateManyTuple(const Tuple& tuple) override;
+ void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
+ void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) override;
+ void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) override;
+ Status MatchesNodeDef(const NodeDef& node_def) override;
+
+ int32 size() override {
+ mutex_lock lock(mu_);
+ return queues_[0].size();
+ }
+
+ int32 capacity() const { return capacity_; }
+
+ private:
+ enum Action { kEnqueue, kDequeue };
+
+ ~FIFOQueue() override {}
+
+ TensorShape ManyOutShape(int i, int64 batch_size) {
+ TensorShape shape({batch_size});
+ shape.AppendShape(component_shapes_[i]);
+ return shape;
+ }
+
+ // Helper for dequeuing a single element from queues_.
+ void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ void Cancel(Action action, CancellationToken token);
+
+ // Helper for cancelling all pending Enqueue(Many) operations when
+ // Close is called with cancel_pending_enqueues.
+ void CloseAndCancel();
+
+ // Tries to enqueue/dequeue (or close) based on whatever is at the
+ // front of enqueue_attempts_/dequeue_attempts_. Appends to
+ // *clean_up the callback for any completed attempt (so it may be
+ // called once mu_ is released). Returns true if any progress was
+ // made.
+ struct CleanUp {
+ CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+ : finished(f), to_deregister(ct), cm(cm) {}
+ DoneCallback finished;
+ CancellationToken to_deregister;
+ CancellationManager* cm;
+ };
+ bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Tries to make progress on the enqueues or dequeues at the front
+ // of the *_attempts_ queues.
+ void FlushUnlocked();
+
+ const int32 capacity_;
+
+ mutex mu_;
+ typedef std::deque<PersistentTensor> SubQueue;
+ std::vector<SubQueue> queues_ GUARDED_BY(mu_);
+ bool closed_ GUARDED_BY(mu_);
+
+ enum RunResult { kNoProgress, kProgress, kComplete };
+ struct Attempt;
+ typedef std::function<RunResult(Attempt*)> RunCallback;
+ struct Attempt {
+ int32 elements_requested;
+ DoneCallback done_callback; // must be run outside mu_
+ OpKernelContext* context;
+ CancellationToken cancellation_token;
+ RunCallback run_callback; // must be run while holding mu_
+ bool is_cancelled;
+ Tuple tuple;
+
+ Attempt(int32 elements_requested, DoneCallback done_callback,
+ OpKernelContext* context, CancellationToken cancellation_token,
+ RunCallback run_callback)
+ : elements_requested(elements_requested),
+ done_callback(done_callback),
+ context(context),
+ cancellation_token(cancellation_token),
+ run_callback(run_callback),
+ is_cancelled(false) {}
+ };
+ std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+ std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
+
+ static Status GetElementComponentFromBatch(const Tuple& tuple, int index,
+ int component,
+ OpKernelContext* ctx,
+ PersistentTensor* out_element);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc
new file mode 100644
index 0000000000..f1088181fe
--- /dev/null
+++ b/tensorflow/core/kernels/fifo_queue_op.cc
@@ -0,0 +1,93 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions and sessions. Running this op produces a single-element
+// tensor of handles to Queues on the corresponding device.
+class FIFOQueueOp : public OpKernel {
+ public:
+ explicit FIFOQueueOp(OpKernelConstruction* context)
+ : OpKernel(context), queue_handle_set_(false) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &queue_handle_, nullptr));
+ if (capacity_ < 0) {
+ capacity_ = FIFOQueue::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+ }
+
+ ~FIFOQueueOp() override {
+ // If the queue object was not shared, delete it.
+ if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ if (!queue_handle_set_) {
+ OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
+ }
+ ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
+ }
+
+ private:
+ Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
+ QueueInterface* queue;
+ auto creator = [this](QueueInterface** ret) {
+ FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
+ component_shapes_, cinfo_.name());
+ *ret = queue;
+ return queue->Initialize();
+ };
+ TF_RETURN_IF_ERROR(
+ cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
+ cinfo_.container(), cinfo_.name(), &queue, creator));
+ core::ScopedUnref unref_me(queue);
+ // Verify that the shared queue is compatible with the requested arguments.
+ TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
+ auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ queue_handle_set_ = true;
+ return Status::OK();
+ }
+
+ int32 capacity_;
+ DataTypeVector component_types_;
+ std::vector<TensorShape> component_shapes_;
+ ContainerInfo cinfo_;
+
+ mutex mu_;
+ PersistentTensor queue_handle_ GUARDED_BY(mu_);
+ bool queue_handle_set_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
+
+} // namespace tensorflow
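
SetQueueHandle relies on ResourceMgr::LookupOrCreate so that the first execution constructs the shared queue and later executions (or other sessions sharing the container) get the same object back. A simplified sketch of that create-once/look-up-afterwards pattern using an ordinary mutex-guarded map in place of the real ResourceMgr:

// Simplified sketch of the LookupOrCreate pattern used above: the first caller
// constructs the shared object via a factory, later callers receive the same one.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Queue {
  explicit Queue(int capacity) : capacity(capacity) {}
  int capacity;
};

class Registry {
 public:
  std::shared_ptr<Queue> LookupOrCreate(
      const std::string& name,
      const std::function<std::shared_ptr<Queue>()>& creator) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = resources_.find(name);
    if (it != resources_.end()) return it->second;  // already created
    auto created = creator();                       // first caller creates it
    resources_[name] = created;
    return created;
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::shared_ptr<Queue>> resources_;
};

int main() {
  Registry registry;
  auto creator = [] { return std::make_shared<Queue>(/*capacity=*/10); };
  auto a = registry.LookupOrCreate("fifo_queue", creator);
  auto b = registry.LookupOrCreate("fifo_queue", creator);  // same object
  std::cout << std::boolalpha << (a == b) << " capacity=" << a->capacity << "\n";
}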
diff --git a/tensorflow/core/kernels/fill_functor.h b/tensorflow/core/kernels/fill_functor.h
new file mode 100644
index 0000000000..831f0c899e
--- /dev/null
+++ b/tensorflow/core/kernels/fill_functor.h
@@ -0,0 +1,26 @@
+#ifndef TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
+#define TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct FillFunctor {
+ // Computes on device "d": out = out.constant(in(0)),
+ void operator()(const Device& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstScalar in);
+};
+
+template <typename Device, typename T>
+struct SetZeroFunctor {
+ // Computes on device "d": out = out.setZero(),
+ void operator()(const Device& d, typename TTypes<T>::Flat out);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
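
The header only declares the functors; each device supplies a definition built from the Eigen expression named in the comment (out = out.constant(in(0))). A hedged sketch of what such a CPU-side fill looks like, with bare Eigen tensors standing in for TTypes<T>::Flat and ConstScalar; it assumes Eigen's unsupported Tensor module is installed and is not the actual TensorFlow definition:

// Sketch of a constant fill using the Eigen Tensor expression described above.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

template <typename T>
void Fill(Eigen::Tensor<T, 1>& out, const Eigen::Tensor<T, 0>& in) {
  // Broadcast the scalar input across every element of the output.
  out = out.constant(in());
}

int main() {
  Eigen::Tensor<float, 1> out(5);
  Eigen::Tensor<float, 0> in;
  in() = 3.5f;
  Fill(out, in);
  for (Eigen::Index i = 0; i < out.dimension(0); ++i) std::cout << out(i) << " ";
  std::cout << "\n";  // 3.5 3.5 3.5 3.5 3.5
}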
diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc
new file mode 100644
index 0000000000..77516ab151
--- /dev/null
+++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc
@@ -0,0 +1,109 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <memory>
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+class FixedLengthRecordReader : public ReaderBase {
+ public:
+ FixedLengthRecordReader(const string& node_name, int64 header_bytes,
+ int64 record_bytes, int64 footer_bytes, Env* env)
+ : ReaderBase(
+ strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
+ header_bytes_(header_bytes),
+ record_bytes_(record_bytes),
+ footer_bytes_(footer_bytes),
+ env_(env),
+ file_pos_limit_(-1),
+ record_number_(0) {}
+
+ // On success:
+ // * input_buffer_ != nullptr,
+ // * input_buffer_->Tell() == header_bytes_
+ // * file_pos_limit_ == file size - footer_bytes_
+ Status OnWorkStartedLocked() override {
+ record_number_ = 0;
+ uint64 file_size = 0;
+ TF_RETURN_IF_ERROR(env_->GetFileSize(current_work(), &file_size));
+ file_pos_limit_ = file_size - footer_bytes_;
+
+ RandomAccessFile* file = nullptr;
+ TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
+ input_buffer_.reset(new io::InputBuffer(file, kBufferSize));
+ TF_RETURN_IF_ERROR(input_buffer_->SkipNBytes(header_bytes_));
+ return Status::OK();
+ }
+
+ Status OnWorkFinishedLocked() override {
+ input_buffer_.reset(nullptr);
+ return Status::OK();
+ }
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ if (input_buffer_->Tell() >= file_pos_limit_) {
+ *at_end = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(input_buffer_->ReadNBytes(record_bytes_, value));
+ *key = strings::StrCat(current_work(), ":", record_number_);
+ *produced = true;
+ ++record_number_;
+ return Status::OK();
+ }
+
+ Status ResetLocked() override {
+ file_pos_limit_ = -1;
+ record_number_ = 0;
+ input_buffer_.reset(nullptr);
+ return ReaderBase::ResetLocked();
+ }
+
+ // TODO(josh11b): Implement serializing and restoring the state.
+
+ private:
+ enum { kBufferSize = 256 << 10 /* 256 kB */ };
+ const int64 header_bytes_;
+ const int64 record_bytes_;
+ const int64 footer_bytes_;
+ Env* const env_;
+ int64 file_pos_limit_;
+ int64 record_number_;
+ std::unique_ptr<io::InputBuffer> input_buffer_;
+};
+
+class FixedLengthRecordReaderOp : public ReaderOpKernel {
+ public:
+ explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1;
+ OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
+ OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
+ OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
+ OP_REQUIRES(context, header_bytes >= 0,
+ errors::InvalidArgument("header_bytes must be >= 0 not ",
+ header_bytes));
+ OP_REQUIRES(context, record_bytes >= 0,
+ errors::InvalidArgument("record_bytes must be >= 0 not ",
+ record_bytes));
+ OP_REQUIRES(context, footer_bytes >= 0,
+ errors::InvalidArgument("footer_bytes must be >= 0 not ",
+ footer_bytes));
+ Env* env = context->env();
+ SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, env]() {
+ return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
+ footer_bytes, env);
+ });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU),
+ FixedLengthRecordReaderOp);
+
+} // namespace tensorflow
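
The reader's bookkeeping is plain offset arithmetic: skip header_bytes_, read record_bytes_ at a time, and stop once the position reaches file_size - footer_bytes_. A tiny sketch of that bookkeeping with hypothetical sizes:

// Sketch of the offsets used above for a made-up 1000-byte file with a
// 16-byte header, an 8-byte footer, and 61-byte records.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t file_size = 1000, header_bytes = 16, footer_bytes = 8,
                record_bytes = 61;

  const int64_t file_pos_limit = file_size - footer_bytes;  // 992
  int64_t pos = header_bytes;  // after SkipNBytes(header_bytes_)
  int64_t records = 0;
  // Mirrors ReadLocked: read another record while Tell() < file_pos_limit_.
  while (pos < file_pos_limit) {
    pos += record_bytes;  // one ReadNBytes(record_bytes_) call
    ++records;
  }
  std::cout << records << " records, final position " << pos << "\n";
  // Prints: 16 records, final position 992 (the records divide evenly here).
}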
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
new file mode 100644
index 0000000000..8bd48f26d6
--- /dev/null
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -0,0 +1,136 @@
+// See docs in ../ops/array_ops.cc.
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+namespace {
+template <typename T, typename Index, int static_slice_elems>
+void HandleCopies(const Tensor& Tparams,
+ typename TTypes<Index>::ConstVec& Tindices, int slice_elems,
+ typename TTypes<T>::Matrix Tout) {
+ const int N = Tindices.dimension(0);
+ const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
+ T* Tout_base = &Tout(0, 0);
+ const T* Tparams_base = &Tparams_flat(0, 0);
+ const size_t slice_bytes = slice_elems * sizeof(T);
+ if (static_slice_elems >= 0) {
+ // Give compiler static knowledge of the number of elements/bytes
+ CHECK_EQ(static_slice_elems, slice_elems);
+ slice_elems = static_slice_elems;
+ }
+ for (int i = 0; i < N; i++) {
+ int j = i + 1;
+ if (j < N) {
+ port::prefetch<port::PREFETCH_HINT_T0>(&Tparams_flat(Tindices(j), 0));
+ port::prefetch<port::PREFETCH_HINT_T0>(&Tout(j, 0));
+ }
+ memcpy(Tout_base + i * slice_elems,
+ Tparams_base + Tindices(i) * slice_elems, slice_bytes);
+ }
+}
+
+} // anonymous namespace
+
+template <typename T, typename Index>
+class GatherOp : public OpKernel {
+ public:
+ // QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
+ // etc. here for the type of the second input argument. Should
+ // we have the framework do some sort of integer promotion
+ // automatically, or should that be something that users have to
+ // do explicitly with a conversion operator in the graph?
+ explicit GatherOp(OpKernelConstruction* c) : OpKernel(c) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ const DataType index_t = DataTypeToEnum<Index>::v();
+ OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& Tparams = c->input(0);
+ const Tensor& Tindices = c->input(1);
+ OP_REQUIRES(
+ c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
+ errors::InvalidArgument("params must be at least 1 dimensional"));
+ const int64 N = Tindices.NumElements();
+ const int64 first_dim_size = Tparams.dim_size(0);
+
+ // Validate all the indices are in range
+ auto Tindices_vec = Tindices.flat<Index>();
+ for (int64 i = 0; i < N; i++) {
+ const Index index = Tindices_vec(i);
+ OP_REQUIRES(c, index >= 0 && index < first_dim_size,
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in Tindices is out of range")));
+ }
+
+ // The result shape is indices.shape + params.shape[1:].
+ TensorShape result_shape = Tindices.shape();
+ for (int i = 1; i < Tparams.dims(); i++) {
+ result_shape.AddDim(Tparams.dim_size(i));
+ }
+
+ Tensor* Tout = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &Tout));
+ const auto& Tparams_flat = Tparams.flat_outer_dims<T>();
+ if (N > 0) {
+ auto Tindices_flat = Tindices.flat<Index>();
+ auto Tout_flat = Tout->shaped<T, 2>({N, Tout->NumElements() / N});
+ if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ const int64 slice_size = Tout->NumElements() / N;
+#define SPECIALIZE(elems) \
+ do { \
+ if (slice_size == elems) { \
+ HandleCopies<T, Index, elems>(Tparams, Tindices_flat, slice_size, \
+ Tout_flat); \
+ return; \
+ } \
+ } while (0)
+
+ SPECIALIZE(10);
+ SPECIALIZE(20);
+
+#undef SPECIALIZE
+
+ HandleCopies<T, Index, -1>(Tparams, Tindices_flat, slice_size,
+ Tout_flat);
+ } else {
+ for (int i = 0; i < N; i++) {
+ int j = i + 1;
+ if (j < N) {
+ port::prefetch<port::PREFETCH_HINT_T0>(
+ &Tparams_flat(Tindices_vec(j), 0));
+ port::prefetch<port::PREFETCH_HINT_T0>(&Tout_flat(j, 0));
+ }
+ // Copy last Ndim-1 dimensions of Tparams[Tindices[i]] to Tout[i]
+ Tout_flat.template chip<0>(i) =
+ Tparams_flat.template chip<0>(Tindices_vec(i));
+ }
+ }
+ }
+ }
+};
+
+#define REGISTER_GATHER(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("Gather") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("Tparams") \
+ .TypeConstraint<index_type>("Tindices"), \
+ GatherOp<type, index_type>)
+
+#define REGISTER_GATHER_INT32(type) REGISTER_GATHER(type, int32)
+#define REGISTER_GATHER_INT64(type) REGISTER_GATHER(type, int64)
+
+TF_CALL_ALL_TYPES(REGISTER_GATHER_INT32);
+TF_CALL_ALL_TYPES(REGISTER_GATHER_INT64);
+
+#undef REGISTER_GATHER_INT32
+#undef REGISTER_GATHER_INT64
+#undef REGISTER_GATHER
+
+} // namespace tensorflow
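
HandleCopies treats params as a [rows, slice_elems] matrix and memcpy's one whole row per index, which is why that path is gated on DataTypeCanUseMemcpy. A minimal standalone version of the row gather over plain arrays (no prefetching or static-size specialization), reusing the data from the Simple_TwoD32 test in the following file:

// Sketch of the row gather performed by HandleCopies: for each index, copy one
// contiguous slice of slice_elems values from params into out.
#include <cstring>
#include <iostream>
#include <vector>

void Gather(const std::vector<float>& params, int slice_elems,
            const std::vector<int>& indices, std::vector<float>* out) {
  out->resize(indices.size() * slice_elems);
  const size_t slice_bytes = slice_elems * sizeof(float);
  for (size_t i = 0; i < indices.size(); ++i) {
    std::memcpy(out->data() + i * slice_elems,
                params.data() + indices[i] * slice_elems, slice_bytes);
  }
}

int main() {
  // params is a 5 x 3 table stored row-major.
  std::vector<float> params = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
  std::vector<int> indices = {0, 4, 0, 2};
  std::vector<float> out;
  Gather(params, /*slice_elems=*/3, indices, &out);
  for (float v : out) std::cout << v << " ";  // 0 1 2 12 13 14 0 1 2 6 7 8
  std::cout << "\n";
}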
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc
new file mode 100644
index 0000000000..d7410169e1
--- /dev/null
+++ b/tensorflow/core/kernels/gather_op_test.cc
@@ -0,0 +1,213 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+class GatherOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType index_type) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "Gather")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(index_type))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(GatherOpTest, ScalarIndices) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({}));
+ test::FillValues<float>(&expected, {3});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(GatherOpTest, Simple_TwoD32) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3}));
+ test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(GatherOpTest, Simple_TwoD64) {
+ MakeOp(DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int64>(TensorShape({4}), {0, 4, 0, 2});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3}));
+ test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(GatherOpTest, HighRank) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3});
+ AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 0, 2, 3, 0});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
+ test::FillValues<float>(&expected, {1, 2, 0, 2, 3, 0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(GatherOpTest, Error_IndexOutOfRange) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
+ AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Index 99 at offset 2 in Tindices is out of range"))
+ << s;
+}
+
+class GatherOpForBenchmark : public GatherOpTest {
+ public:
+ void TestBody() override {} // not used
+ void PublicMakeOp(DataType index_type) { MakeOp(index_type); }
+};
+
+static const int kSorted = 0x8000; // Mask for arg to specify sorting vs. not
+
+template <typename Index>
+void BM_Gather(int iters, int arg) {
+ testing::StopTiming();
+
+ bool sorted = ((arg & kSorted) != 0);
+ int dim = arg & ~kSorted;
+
+ GatherOpForBenchmark t;
+ t.PublicMakeOp(DataTypeToEnum<Index>::v());
+ // Use a 512 MB table, regardless of dim
+ const int kRows = ((1 << 29) / sizeof(float)) / dim;
+ std::vector<float> data(kRows * dim, 1.0f);
+ t.AddInputFromArray<float>(TensorShape({kRows, dim}), data);
+ const int kLookups = 2000;
+ const int kBatches = 1000000 / kLookups;
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<std::vector<Index>> all_ids(kBatches);
+ for (int i = 0; i < kBatches; ++i) {
+ std::vector<Index>* ids = &all_ids[i];
+ ids->resize(kLookups);
+ for (int j = 0; j < kLookups; ++j) {
+ (*ids)[j] = rnd.Uniform(kRows);
+ }
+ if (sorted) {
+ sort(ids->begin(), ids->end());
+ }
+ }
+
+ t.AddInput<Index>(TensorShape({kLookups}), [](int i) { return 0; });
+ if (sorted) {
+ testing::SetLabel("sorted by id");
+ }
+ testing::BytesProcessed(static_cast<int64>(iters) * kLookups * dim *
+ sizeof(float));
+ testing::StartTiming();
+ while (--iters > 0) {
+ const std::vector<Index>& b = all_ids[iters % kBatches];
+ TensorValue input = t.mutable_input(1);
+ gtl::MutableArraySlice<Index> slice(&input->vec<Index>()(0),
+ input->NumElements());
+ for (int i = 0; i < kLookups; i++) {
+ slice[i] = b[i];
+ }
+ Status s = t.RunOpKernel();
+ }
+}
+
+static void BM_Gather32(int iters, int arg) { BM_Gather<int32>(iters, arg); }
+
+static void BM_Gather64(int iters, int arg) { BM_Gather<int64>(iters, arg); }
+
+BENCHMARK(BM_Gather32)
+ ->Arg(10)
+ ->Arg(10 | kSorted)
+ ->Arg(20)
+ ->Arg(40)
+ ->Arg(63)
+ ->Arg(63 | kSorted)
+ ->Arg(64)
+ ->Arg(64 | kSorted)
+ ->Arg(65)
+ ->Arg(65 | kSorted)
+ ->Arg(100)
+ ->Arg(100 | kSorted)
+ ->Arg(127)
+ ->Arg(127 | kSorted)
+ ->Arg(128)
+ ->Arg(128 | kSorted)
+ ->Arg(129)
+ ->Arg(129 | kSorted)
+ ->Arg(1000)
+ ->Arg(1000 | kSorted);
+
+BENCHMARK(BM_Gather64)
+ ->Arg(10)
+ ->Arg(10 | kSorted)
+ ->Arg(20)
+ ->Arg(40)
+ ->Arg(63)
+ ->Arg(63 | kSorted)
+ ->Arg(64)
+ ->Arg(64 | kSorted)
+ ->Arg(65)
+ ->Arg(65 | kSorted)
+ ->Arg(100)
+ ->Arg(100 | kSorted)
+ ->Arg(127)
+ ->Arg(127 | kSorted)
+ ->Arg(128)
+ ->Arg(128 | kSorted)
+ ->Arg(129)
+ ->Arg(129 | kSorted)
+ ->Arg(1000)
+ ->Arg(1000 | kSorted);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/identity_op.cc b/tensorflow/core/kernels/identity_op.cc
new file mode 100644
index 0000000000..b29efbddfb
--- /dev/null
+++ b/tensorflow/core/kernels/identity_op.cc
@@ -0,0 +1,45 @@
+// See docs in ../ops/array_ops.cc.
+#include "tensorflow/core/kernels/identity_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("Identity").Device(DEVICE_CPU), IdentityOp);
+// StopGradient does the same thing as Identity, but has a different
+// gradient registered.
+REGISTER_KERNEL_BUILDER(Name("StopGradient").Device(DEVICE_CPU), IdentityOp);
+
+REGISTER_KERNEL_BUILDER(Name("RefIdentity").Device(DEVICE_CPU), IdentityOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Identity").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ IdentityOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RefIdentity").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ IdentityOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StopGradient").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ IdentityOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(bool);
+REGISTER_GPU_KERNEL(bfloat16);
+
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Identity")
+ .Device(DEVICE_GPU)
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ IdentityOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/identity_op.h b/tensorflow/core/kernels/identity_op.h
new file mode 100644
index 0000000000..7adc1eace0
--- /dev/null
+++ b/tensorflow/core/kernels/identity_op.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_KERNELS_IDENTITY_OP_H_
+#define TENSORFLOW_KERNELS_IDENTITY_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class IdentityOp : public OpKernel {
+ public:
+ explicit IdentityOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ if (IsRefType(context->input_dtype(0))) {
+ context->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ context->set_output(0, context->input(0));
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_IDENTITY_OP_H_
diff --git a/tensorflow/core/kernels/identity_op_test.cc b/tensorflow/core/kernels/identity_op_test.cc
new file mode 100644
index 0000000000..6483367a79
--- /dev/null
+++ b/tensorflow/core/kernels/identity_op_test.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class IdentityOpTest : public OpsTestBase {
+ protected:
+ Status Init(DataType input_type) {
+ RequireDefaultOps();
+ TF_CHECK_OK(NodeDefBuilder("op", "Identity")
+ .Input(FakeInput(input_type))
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(IdentityOpTest, Int32Success_6) {
+ ASSERT_OK(Init(DT_INT32));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(IdentityOpTest, Int32Success_2_3) {
+ ASSERT_OK(Init(DT_INT32));
+ AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({2, 3}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(IdentityOpTest, StringSuccess) {
+ ASSERT_OK(Init(DT_STRING));
+ AddInputFromArray<string>(TensorShape({6}), {"A", "b", "C", "d", "E", "f"});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({6}));
+ test::FillValues<string>(&expected, {"A", "b", "C", "d", "E", "f"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(IdentityOpTest, RefInputError) { ASSERT_OK(Init(DT_INT32_REF)); }
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/identity_reader_op.cc b/tensorflow/core/kernels/identity_reader_op.cc
new file mode 100644
index 0000000000..a63fea5dbb
--- /dev/null
+++ b/tensorflow/core/kernels/identity_reader_op.cc
@@ -0,0 +1,57 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <memory>
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+class IdentityReader : public ReaderBase {
+ public:
+ explicit IdentityReader(const string& node_name)
+ : ReaderBase(strings::StrCat("IdentityReader '", node_name, "'")) {}
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ *key = current_work();
+ *value = current_work();
+ *produced = true;
+ *at_end = true;
+ return Status::OK();
+ }
+
+ // Stores state in a ReaderBaseState proto, since IdentityReader has
+ // no additional state beyond ReaderBase.
+ Status SerializeStateLocked(string* state) override {
+ ReaderBaseState base_state;
+ SaveBaseState(&base_state);
+ base_state.SerializeToString(state);
+ return Status::OK();
+ }
+
+ Status RestoreStateLocked(const string& state) override {
+ ReaderBaseState base_state;
+ if (!ParseProtoUnlimited(&base_state, state)) {
+ return errors::InvalidArgument("Could not parse state for ", name(), ": ",
+ str_util::CEscape(state));
+ }
+ TF_RETURN_IF_ERROR(RestoreBaseState(base_state));
+ return Status::OK();
+ }
+};
+
+class IdentityReaderOp : public ReaderOpKernel {
+ public:
+ explicit IdentityReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ SetReaderFactory([this]() { return new IdentityReader(name()); });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IdentityReader").Device(DEVICE_CPU),
+ IdentityReaderOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/in_topk_op.cc b/tensorflow/core/kernels/in_topk_op.cc
new file mode 100644
index 0000000000..d08f6f53da
--- /dev/null
+++ b/tensorflow/core/kernels/in_topk_op.cc
@@ -0,0 +1,58 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+template <typename T>
+class InTopK : public OpKernel {
+ public:
+ explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const auto& predictions_in = context->input(0);
+ const auto& targets_in = context->input(1);
+ OP_REQUIRES(context, predictions_in.dims() == 2,
+ errors::InvalidArgument("predictions must be 2-dimensional"));
+ OP_REQUIRES(context, targets_in.dims() == 1,
+ errors::InvalidArgument("targets must be 1-dimensional"));
+ OP_REQUIRES(context, predictions_in.dim_size(0) == targets_in.dim_size(0),
+ errors::InvalidArgument("First dimension of predictions ",
+ predictions_in.dim_size(0),
+ " must match length of targets ",
+ targets_in.dim_size(0)));
+ const auto& predictions = predictions_in.matrix<T>();
+ const auto& targets = targets_in.vec<int>();
+
+ Tensor* t_out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({targets_in.dim_size(0)}), &t_out));
+ auto out = t_out->vec<bool>();
+
+ const auto size = targets.size();
+ const auto num_classes = predictions.dimension(1);
+ for (int b = 0; b < size; b++) {
+ T target_prediction = predictions(b, targets(b));
+ int more_probable_classes = 0;
+ for (int i = 0; i < num_classes; ++i) {
+ if (predictions(b, i) > target_prediction) ++more_probable_classes;
+ }
+ out(b) = more_probable_classes < k_;
+ }
+ }
+
+ private:
+ int k_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("InTopK").Device(DEVICE_CPU), InTopK<float>);
+
+} // namespace tensorflow
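
Per row, the kernel counts how many classes score strictly higher than the target class and reports true when that count is below k, so ties do not push the target out of the top k. A small sketch of that rule with made-up predictions:

// Sketch of the InTopK rule above: the target is "in the top k" when fewer than
// k classes score strictly higher than it.
#include <iostream>
#include <vector>

bool InTopK(const std::vector<float>& predictions, int target, int k) {
  int more_probable_classes = 0;
  for (float p : predictions) {
    if (p > predictions[target]) ++more_probable_classes;
  }
  return more_probable_classes < k;
}

int main() {
  std::vector<float> row = {0.1f, 0.6f, 0.2f, 0.1f};
  std::cout << std::boolalpha;
  std::cout << InTopK(row, /*target=*/2, /*k=*/1) << "\n";  // false: class 1 is higher
  std::cout << InTopK(row, /*target=*/2, /*k=*/2) << "\n";  // true: only one class is higher
}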
diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc
new file mode 100644
index 0000000000..7f8b070556
--- /dev/null
+++ b/tensorflow/core/kernels/initializable_lookup_table.cc
@@ -0,0 +1,41 @@
+#include "tensorflow/core/kernels/initializable_lookup_table.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace lookup {
+
+Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values,
+ const Tensor& default_value) {
+ if (!is_initialized()) {
+ return errors::FailedPrecondition("Table not initialized.");
+ }
+ TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value));
+ return DoFind(keys, values, default_value);
+}
+
+Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
+ if (!iter.Valid()) {
+ return iter.status();
+ }
+ TF_RETURN_IF_ERROR(CheckKeyAndValueTensors(iter.keys(), iter.values()));
+
+ mutex_lock l(mu_);
+ if (is_initialized()) {
+ return errors::FailedPrecondition("Table already initialized.");
+ }
+
+ TF_RETURN_IF_ERROR(DoPrepare(iter.total_size()));
+ while (iter.Valid()) {
+ TF_RETURN_IF_ERROR(DoInsert(iter.keys(), iter.values()));
+ iter.Next();
+ }
+ if (!errors::IsOutOfRange(iter.status())) {
+ return iter.status();
+ }
+ is_initialized_ = true;
+ return Status::OK();
+}
+
+} // namespace lookup
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
new file mode 100644
index 0000000000..651b491457
--- /dev/null
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -0,0 +1,103 @@
+#ifndef TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#define TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+
+#include "tensorflow/core/framework/lookup_interface.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Base class for lookup tables that require initialization.
+class InitializableLookupTable : public LookupInterface {
+ public:
+ class InitTableIterator;
+
+ // Performs batch lookups: for every element in the key tensor, Find writes
+ // the corresponding value into the values tensor. If an element is not
+ // present in the table, the given default value is used instead.
+ //
+ // For tables that require initialization, `Find` is available once the table
+ // is marked as initialized.
+ //
+ // Returns the following statuses:
+ // - OK: when the find finishes successfully.
+ // - FailedPrecondition: if the table is not initialized.
+ // - InvalidArgument: if any of the preconditions on the lookup key or value
+ // fails.
+ // - In addition, other implementations may provide another non-OK status
+ // specific to their failure modes.
+ Status Find(const Tensor& keys, Tensor* values,
+ const Tensor& default_value) final;
+
+ // Returns whether the table was initialized and is ready to serve lookups.
+ bool is_initialized() const { return is_initialized_; }
+
+ // Initializes the table from the given init table iterator.
+ //
+ // Atomically, this operation prepares the table, populates it with the given
+ // iterator, and marks the table as initialized.
+ //
+ // Returns the following statuses:
+ // - OK: when the initialization was successful.
+ // - InvalidArgument: if any of the preconditions on the lookup key or value
+ // fails.
+ // - FailedPrecondition: if the table is already initialized.
+ // - In addition, other implementations may provide another non-OK status
+ // specific to their failure modes.
+ Status Initialize(InitTableIterator& iter);
+
+ // Basic iterator to initialize lookup tables.
+ // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that
+ // the consumer may insert key-value pairs in batches.
+ //
+ // When the iterator is exhausted, Valid() returns false and status() returns
+ // an OutOfRange error.
+ class InitTableIterator {
+ public:
+ InitTableIterator() {}
+
+ virtual ~InitTableIterator() {}
+
+ // Prepares the next batch of key and value tensors.
+ virtual void Next() = 0;
+
+ // Returns true if keys and values point to valid tensors.
+ virtual bool Valid() const = 0;
+
+ // Returns a tensor that contains the current batch of 'key' values.
+ virtual const Tensor& keys() const = 0;
+
+ // Returns a tensor that contains the current batch of 'value' values.
+ virtual const Tensor& values() const = 0;
+
+ // Returns an error if one has occurred, otherwise returns Status::OK.
+ virtual Status status() const = 0;
+
+ // Returns the total number of elements that the iterator will produce.
+ virtual int64 total_size() const = 0;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator);
+ };
+
+ protected:
+ // Prepares and allocates the underlying data structure to store the given
+ // number of expected elements.
+ virtual Status DoPrepare(size_t expected_num_elements) = 0;
+
+ // Inserts the given batch of keys and values tensors into the underlying
+ // data structure.
+ virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
+
+ // Performs the batch find operation on the underlying data structure.
+ virtual Status DoFind(const Tensor& keys, Tensor* values,
+ const Tensor& default_value) = 0;
+
+ mutex mu_;
+ bool is_initialized_ = false;
+};
+
+} // namespace lookup
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
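A sketch of the intended call sequence for this interface, assuming a concrete table subclass and some InitTableIterator implementation (both supplied by the caller; the helper name is hypothetical). Initialize must succeed before Find may be called:

    #include "tensorflow/core/kernels/initializable_lookup_table.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/public/tensor.h"

    namespace tensorflow {

    // `table` is any InitializableLookupTable subclass and `iter` is any
    // InitTableIterator implementation.
    Status InitializeAndQuery(
        lookup::InitializableLookupTable* table,
        lookup::InitializableLookupTable::InitTableIterator* iter,
        const Tensor& keys, Tensor* values, const Tensor& default_value) {
      // Runs DoPrepare() and then DoInsert() for every batch yielded by iter.
      TF_RETURN_IF_ERROR(table->Initialize(*iter));
      // From this point on the table is read-only and Find() is allowed.
      return table->Find(keys, values, default_value);
    }

    }  // namespace tensorflow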
diff --git a/tensorflow/core/kernels/io.cc b/tensorflow/core/kernels/io.cc
new file mode 100644
index 0000000000..9d6921aa8e
--- /dev/null
+++ b/tensorflow/core/kernels/io.cc
@@ -0,0 +1,270 @@
+// See docs in ../ops/io_ops.cc
+#include <unordered_map>
+
+#include "tensorflow/core/kernels/io.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace tensorflow {
+
+namespace {
+bool ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
+ TensorSlice* slice, TensorShape* shape_slice,
+ string* error) {
+ CHECK(!shape_and_slice.empty());
+ // Syntax: dim0 dim1 dim2 ... <slice string>
+ // Where slice string is defined in core/framework/tensor_slice.h
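+ // Example (illustrative): "4 5 0,2:-" describes a tensor of shape [4, 5]
+ // together with a slice covering rows [0, 2) and all columns.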
+ std::vector<string> splits = str_util::Split(shape_and_slice, ' ');
+
+ // Must have at least 2 strings.
+ if (splits.size() < 2) {
+ *error = strings::StrCat(
+ "Need least two elements in shape_and_slice specification: ",
+ shape_and_slice);
+ return false;
+ }
+ int num_dims = splits.size() - 1;
+ shape->Clear();
+ for (int i = 0; i < num_dims; ++i) {
+ int dim;
+ if (!str_util::NumericParse32(splits[i], &dim)) {
+ *error = strings::StrCat("Non numerical dimension in shape_and_slice: ",
+ shape_and_slice);
+ return false;
+ }
+ shape->AddDim(dim);
+ }
+ // The last split is the slice specification.
+ slice->Clear();
+ auto status = slice->Parse(splits.back(), slice);
+ if (!status.ok()) {
+ *error = status.error_message();
+ return false;
+ }
+ // The specified slice must be compatible with the specified shape.
+ status = slice->SliceTensorShape(*shape, shape_slice);
+ if (!status.ok()) {
+ *error = status.error_message();
+ return false;
+ }
+ return true;
+}
+} // namespace
+
+void SaveTensors(
+ OpKernelContext* context,
+ checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
+ bool save_slices) {
+ const Tensor& filename_t = context->input(0);
+ {
+ const int64 size = filename_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ errors::InvalidArgument(
+ "Input 0 (filename) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+
+ const Tensor& tensor_names_t = context->input(1);
+ const int64 N = tensor_names_t.NumElements();
+ const string* tensor_shapes_and_slices_ptr = nullptr;
+ if (save_slices) {
+ const Tensor& tensor_shapes_and_slices_t = context->input(2);
+ OP_REQUIRES(
+ context, tensor_shapes_and_slices_t.NumElements() == N,
+ errors::InvalidArgument("Expected ", N,
+ " elements for the tensor "
+ "shapes and slices but got ",
+ tensor_shapes_and_slices_t.NumElements()));
+ tensor_shapes_and_slices_ptr =
+ tensor_shapes_and_slices_t.flat<string>().data();
+ }
+ // Path, names, and slices if save_slices is true.
+ const int kFixedInputs = save_slices ? 3 : 2;
+ OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
+ errors::InvalidArgument("Expected totally ", N + kFixedInputs,
+ " inputs as input #1 (which is a string "
+ "tensor of saved names) contains ",
+ N, " names, but received ",
+ context->num_inputs(), " inputs"));
+
+ VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0)
+ << "...";
+ checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0),
+ builder_func);
+
+ Status s;
+ auto tensor_names_flat = tensor_names_t.flat<string>();
+
+ string error;
+ for (int64 i = 0; i < N; ++i) {
+ const string& name = tensor_names_flat(i);
+ const Tensor& input = context->input(i + kFixedInputs);
+ TensorShape shape(input.shape());
+ TensorSlice slice(input.dims());
+ if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
+ const string& shape_spec = tensor_shapes_and_slices_ptr[i];
+ TensorShape slice_shape;
+ OP_REQUIRES(context, ParseShapeAndSlice(shape_spec, &shape, &slice,
+ &slice_shape, &error),
+ errors::InvalidArgument(error));
+ OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
+ errors::InvalidArgument("Slice in shape_and_slice "
+ "specification does not match the "
+ "shape of the tensor to save: ",
+ shape_spec, ", tensor: ",
+ input.shape().DebugString()));
+ }
+
+#define WRITER_ADD(dt) \
+ case dt: \
+ s = writer.Add(name, shape, slice, \
+ input.flat<EnumToDataType<dt>::Type>().data()); \
+ break
+
+ switch (input.dtype()) {
+ WRITER_ADD(DT_FLOAT);
+ WRITER_ADD(DT_DOUBLE);
+ WRITER_ADD(DT_INT32);
+ WRITER_ADD(DT_UINT8);
+ WRITER_ADD(DT_INT16);
+ WRITER_ADD(DT_INT8);
+ WRITER_ADD(DT_INT64);
+ WRITER_ADD(DT_QUINT8);
+ WRITER_ADD(DT_QINT8);
+ WRITER_ADD(DT_QINT32);
+ default:
+ context->SetStatus(errors::Unimplemented("Saving data type ",
+ DataTypeString(input.dtype()),
+ " not yet supported"));
+ return;
+ }
+#undef WRITER_ADD
+ if (!s.ok()) {
+ context->SetStatus(s);
+ return;
+ }
+ }
+
+ s = writer.Finish();
+ if (!s.ok()) {
+ context->SetStatus(s);
+ }
+}
+
+void RestoreTensor(OpKernelContext* context,
+ checkpoint::TensorSliceReader::OpenTableFunction open_func,
+ int preferred_shard, bool restore_slice) {
+ const Tensor& file_pattern_t = context->input(0);
+ {
+ const int64 size = file_pattern_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ errors::InvalidArgument(
+ "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+ const string& file_pattern = file_pattern_t.flat<string>()(0);
+
+ const Tensor& tensor_name_t = context->input(1);
+ {
+ const int64 size = tensor_name_t.NumElements();
+ OP_REQUIRES(
+ context, size == 1,
+ errors::InvalidArgument(
+ "Input 1 (tensor_name) must be a string scalar; got a tensor of ",
+ size, "elements"));
+ }
+ const string& tensor_name = tensor_name_t.flat<string>()(0);
+
+ const string* tensor_shape_and_slice_ptr = nullptr;
+ if (restore_slice) {
+ const Tensor& tensor_shape_and_slice_t = context->input(2);
+ OP_REQUIRES(
+ context, tensor_shape_and_slice_t.NumElements() == 1,
+ errors::InvalidArgument("Expected 1 element for the tensor "
+ "shape and slice but got ",
+ tensor_shape_and_slice_t.NumElements()));
+ tensor_shape_and_slice_ptr = tensor_shape_and_slice_t.flat<string>().data();
+ }
+
+ // If we cannot find a cached reader we will allocate our own.
+ std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
+
+ const checkpoint::TensorSliceReader* reader =
+ context->slice_reader_cache()->GetReader(file_pattern, open_func,
+ preferred_shard);
+ if (!reader) {
+ allocated_reader.reset(new checkpoint::TensorSliceReader(
+ file_pattern, open_func, preferred_shard));
+ reader = allocated_reader.get();
+ }
+ OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
+
+ // Get the shape and type from the save file.
+ DataType type;
+ TensorShape saved_shape;
+ OP_REQUIRES(
+ context, reader->HasTensor(tensor_name, &saved_shape, &type),
+ errors::NotFound("Tensor name \"", tensor_name,
+ "\" not found in checkpoint files ", file_pattern));
+ OP_REQUIRES(
+ context, type == context->expected_output_dtype(0),
+ errors::InvalidArgument("Expected to restore a tensor of type ",
+ DataTypeString(context->expected_output_dtype(0)),
+ ", got a tensor of type ", DataTypeString(type),
+ " instead: tensor_name = ", tensor_name));
+
+ // Shape of the output and slice to load.
+ TensorShape output_shape(saved_shape);
+ TensorSlice slice_to_load(saved_shape.dims());
+ if (restore_slice && !tensor_shape_and_slice_ptr[0].empty()) {
+ const string& shape_spec = tensor_shape_and_slice_ptr[0];
+ TensorShape parsed_shape;
+ string error;
+ OP_REQUIRES(context,
+ ParseShapeAndSlice(shape_spec, &parsed_shape, &slice_to_load,
+ &output_shape, &error),
+ errors::InvalidArgument(error));
+ OP_REQUIRES(
+ context, parsed_shape.IsSameSize(saved_shape),
+ errors::InvalidArgument(
+ "Shape in shape_and_slice spec does not match the shape in the "
+ "save file: ",
+ parsed_shape.DebugString(), ", save file shape: ",
+ saved_shape.DebugString()));
+ }
+
+ Tensor* t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &t));
+#define READER_COPY(dt) \
+ case dt: \
+ reader->CopySliceData(tensor_name, slice_to_load, \
+ t->flat<EnumToDataType<dt>::Type>().data()); \
+ break
+
+ switch (type) {
+ READER_COPY(DT_FLOAT);
+ READER_COPY(DT_DOUBLE);
+ READER_COPY(DT_INT32);
+ READER_COPY(DT_UINT8);
+ READER_COPY(DT_INT16);
+ READER_COPY(DT_INT8);
+ READER_COPY(DT_INT64);
+ default:
+ context->SetStatus(errors::Unimplemented(
+ "Restoring data type ", DataTypeString(type), " not yet supported"));
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/io.h b/tensorflow/core/kernels/io.h
new file mode 100644
index 0000000000..7e548f1ad0
--- /dev/null
+++ b/tensorflow/core/kernels/io.h
@@ -0,0 +1,38 @@
+#ifndef TENSORFLOW_KERNELS_IO_H_
+#define TENSORFLOW_KERNELS_IO_H_
+
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+// Save input tensors in *context to a writer built from builder_func().
+// context must have the following inputs:
+// 0: a single element string tensor that contains the file name.
+// 1: names for the remaining tensors
+// If save_slices is true:
+// 2: shape and slice specifications.
+// rest: tensors to save
+void SaveTensors(
+ OpKernelContext* context,
+ checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
+ bool save_slices);
+
+// Reads a tensor from the reader built from open_func() and produces it as
+// context->output(0). "preferred_shard" is the same the TensorSliceReader
+// preferred_shard parameter.
+//
+// context must have the following inputs:
+// 0: a single element string tensor that contains the file pattern.
+// 1: a single element string tensor that names the output to be restored.
+// If restore_slice is true:
+// 2: shape and slice specification of the tensor to restore.
+void RestoreTensor(OpKernelContext* context,
+ checkpoint::TensorSliceReader::OpenTableFunction open_func,
+ int preferred_shard, bool restore_slice);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_IO_H_
diff --git a/tensorflow/core/kernels/l2loss_op.cc b/tensorflow/core/kernels/l2loss_op.cc
new file mode 100644
index 0000000000..6f83f01676
--- /dev/null
+++ b/tensorflow/core/kernels/l2loss_op.cc
@@ -0,0 +1,69 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/l2loss_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class L2LossOp : public OpKernel {
+ public:
+ explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // The input tensor can be of any number of dimensions, even though it's
+ // 2D in most typical applications.
+ const Tensor& input = context->input(0);
+ // The output is a single number.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+ functor::L2Loss<Device, T>()(context->eigen_device<Device>(),
+ input.flat<T>(), output->scalar<T>());
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("L2Loss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ L2LossOp<CPUDevice, T>);
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void L2Loss<GPUDevice, T>::operator()(const GPUDevice& d, \
+ typename TTypes<T>::ConstTensor input, \
+ typename TTypes<T>::Scalar output); \
+ extern template struct L2Loss<GPUDevice, T>;
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("L2Loss").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ L2LossOp<GPUDevice, T>);
+
+REGISTER_GPU_KERNEL(float);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/l2loss_op.h b/tensorflow/core/kernels/l2loss_op.h
new file mode 100644
index 0000000000..d307353e24
--- /dev/null
+++ b/tensorflow/core/kernels/l2loss_op.h
@@ -0,0 +1,24 @@
+#ifndef TENSORFLOW_KERNELS_L2LOSS_OP_H_
+#define TENSORFLOW_KERNELS_L2LOSS_OP_H_
+// Functor definition for L2LossOp, must be compilable by nvcc.
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by L2LossOp to do the computations.
+template <typename Device, typename T>
+struct L2Loss {
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor input,
+ typename TTypes<T>::Scalar output) {
+ // We flatten the input tensor and reduce on dimension 0, producing
+ // a single number which is Mul(Sum(x^2), 0.5).
+ output.device(d) = input.square().sum() * static_cast<T>(0.5);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_L2LOSS_OP_H_
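The functor reduces the whole (flattened) input to a single scalar, 0.5 * sum(x_i^2). A quick standalone check of that formula in plain C++ (no Eigen), for intuition:

    #include <vector>

    // L2 loss of {1, 2, 3} is 0.5 * (1 + 4 + 9) = 7.
    float L2Loss(const std::vector<float>& x) {
      float sum_of_squares = 0.0f;
      for (float v : x) sum_of_squares += v * v;
      return 0.5f * sum_of_squares;
    }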
diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
new file mode 100644
index 0000000000..858fcfe8d3
--- /dev/null
+++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc
@@ -0,0 +1,16 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/l2loss_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+template struct functor::L2Loss<GPUDevice, float>;
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
new file mode 100644
index 0000000000..93342a7a24
--- /dev/null
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -0,0 +1,99 @@
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+
+namespace tensorflow {
+
+void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
+ const Tensor& in = context->input(0);
+
+ const int input_rank = GetInputMatrixRank();
+ OP_REQUIRES(
+ context, input_rank == 2,
+ errors::InvalidArgument("Only matrix inputs are supported so far."));
+ if (SupportsBatchOperation()) {
+ OP_REQUIRES(context, in.dims() > input_rank,
+ errors::InvalidArgument("Input tensor must have rank >= %d",
+ input_rank + 1));
+ } else {
+ OP_REQUIRES(context, in.dims() == input_rank,
+ errors::InvalidArgument("Input tensor must have rank == %d",
+ input_rank));
+ }
+
+ // If the tensor rank is greater than input_rank, we consider the inner-most
+ // dimensions as matrices, and loop over all the other outer
+ // dimensions to compute the results.
+ // TODO(kalakris): Only matrix inputs are currently supported.
+ const int row_dimension = in.dims() - 2;
+ const int col_dimension = in.dims() - 1;
+ const int64 num_rows = in.dim_size(row_dimension);
+ const int64 num_cols = in.dim_size(col_dimension);
+ const TensorShape input_matrix_shape = TensorShape({num_rows, num_cols});
+ const TensorShape output_matrix_shape =
+ GetOutputMatrixShape(input_matrix_shape);
+ OP_REQUIRES(context, output_matrix_shape.dims() <= 2,
+ errors::InvalidArgument("Output rank must be 1 or 2."));
+
+ int num_matrices = 1;
+ // The output has the shape of all the outer dimensions of the input
+ // except for the last two, plus the output_matrix_shape (if the output
+ // is not scalar). This still assumes that each input matrix is
+ // 2-dimensional, in accordance with the TODO above.
+ TensorShape output_shape;
+ if (in.dims() == 2) {
+ output_shape = output_matrix_shape;
+ } else {
+ for (int dim = 0; dim <= in.dims() - 3; ++dim) {
+ num_matrices *= in.dim_size(dim);
+ output_shape.AddDim(in.dim_size(dim));
+ }
+ for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) {
+ output_shape.AddDim(output_matrix_shape.dim_size(dim));
+ }
+ }
+
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &out));
+
+ auto shard = [this, &in, &input_matrix_shape, &output_matrix_shape, context,
+ out](int64 begin, int64 end) {
+ for (int64 i = begin; i < end; ++i) {
+ ComputeMatrix(context, i, in, input_matrix_shape, out,
+ output_matrix_shape);
+ }
+ };
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, num_matrices,
+ GetCostPerUnit(input_matrix_shape), shard);
+}
+
+template <typename Scalar, bool SupportsBatchOperationT>
+void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
+ OpKernelContext* context, int64 matrix_index, const Tensor& in,
+ const TensorShape& input_matrix_shape, Tensor* out,
+ const TensorShape& output_matrix_shape) {
+ // TODO(kalakris): Handle alignment if possible. Eigen::Map is
+ // unaligned by default.
+ ConstMatrixMap input(in.flat<Scalar>().data() +
+ matrix_index * input_matrix_shape.num_elements(),
+ input_matrix_shape.dim_size(0),
+ input_matrix_shape.dim_size(1));
+
+ // The output may not be rank 2: it can be a scalar or a vector.
+ int num_output_rows =
+ output_matrix_shape.dims() >= 1 ? output_matrix_shape.dim_size(0) : 1;
+ int num_output_cols =
+ output_matrix_shape.dims() == 2 ? output_matrix_shape.dim_size(1) : 1;
+ MatrixMap output(out->flat<Scalar>().data() +
+ matrix_index * output_matrix_shape.num_elements(),
+ num_output_rows, num_output_cols);
+ ComputeMatrix(context, input, &output);
+}
+
+// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
+template class LinearAlgebraOp<float, false>;
+template class LinearAlgebraOp<float, true>;
+template class LinearAlgebraOp<double, false>;
+template class LinearAlgebraOp<double, true>;
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
new file mode 100644
index 0000000000..471f11e25f
--- /dev/null
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -0,0 +1,123 @@
+#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/work_sharder.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// A base class to support linear algebra functionality, similar to the
+// numpy.linalg module. Supports batch computation on several matrices at once,
+// sharding the computations across different threads if necessary.
+//
+// TODO(kalakris): This needs to be expanded to support binary inputs, and
+// multiple outputs.
+class LinearAlgebraOpBase : public OpKernel {
+ public:
+ explicit LinearAlgebraOpBase(OpKernelConstruction* context)
+ : OpKernel(context) {}
+ ~LinearAlgebraOpBase() override {}
+
+ // Return the expected rank of the input.
+ // TODO(kalakris): This should be a virtual function to support vector inputs.
+ int GetInputMatrixRank() { return 2; }
+
+ // Return the output shape of each individual matrix operation. Must be
+ // rank 0, 1, or 2. Scalar outputs are rank 0.
+ virtual TensorShape GetOutputMatrixShape(
+ const TensorShape& input_matrix_shape) = 0;
+
+ // Return the cost per matrix operation. Cost per unit is assumed to be
+ // roughly 1ns, based on comments in core/util/work_sharder.cc.
+ virtual int64 GetCostPerUnit(const TensorShape& input_matrix_shape) = 0;
+
+ // If SupportsBatchOperation() returns false, this Op will only accept rank 2
+ // (if the supported input type is a matrix). If it returns true, the Op will
+ // accept inputs of rank >= 3, and repeatedly execute the operation on all
+ // matrices in the innermost two dimensions.
+ virtual bool SupportsBatchOperation() = 0;
+
+ // Perform the actual computation on an input matrix, and store the results
+ // in the output. This will be called repeatedly for a single call to
+ // Compute(), if multiple matrices exist in the input Tensor.
+ //
+ // This function should only compute the results for a single input matrix.
+ // The 'matrix_index' parameter specifies the index of the matrix to be used
+ // from the input, and the index of the matrix to be written to in the output.
+ // The input matrix is in row major order, and is located at the memory
+ // address
+ // in.flat<Scalar>().data() +
+ // matrix_index * input_matrix_shape.num_elements().
+ // The output matrix is in row major order, and is located at the memory
+ // address
+ // out->flat<Scalar>().data() +
+ // matrix_index * output_matrix_shape.num_elements().
+ // The LinearAlgebraOp<Scalar> class below has functionality which performs
+ // this mapping and presents an interface based on the Eigen::MatrixBase API.
+ virtual void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
+ const Tensor& in,
+ const TensorShape& input_matrix_shape, Tensor* out,
+ const TensorShape& output_matrix_shape) = 0;
+
+ void Compute(OpKernelContext* context) override;
+};
+
+// A base class for linear algebra ops templated on the scalar type.
+//
+// This base class encapsulates the functionality of mapping the input and
+// output tensors using Eigen::Map, so that the Eigen::MatrixBase API may be
+// directly used by derived classes.
+// SupportsBatchOperationT is a bool template argument which if set to true
+// will allow the Op to process batches of matrices (rank >= 3); if set to
+// false the Op will only accept rank 2 inputs.
+template <typename Scalar, bool SupportsBatchOperationT>
+class LinearAlgebraOp : public LinearAlgebraOpBase {
+ public:
+ explicit LinearAlgebraOp(OpKernelConstruction* context)
+ : LinearAlgebraOpBase(context) {}
+
+ using ConstMatrixMap =
+ Eigen::Map<const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>;
+ using MatrixMap = Eigen::Map<
+ Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
+
+ // Perform the actual computation on the input matrix, and store the results
+ // in the output. This will be called repeatedly for a single call to
+ // Compute(), if multiple matrices exist in the input Tensor.
+ virtual void ComputeMatrix(OpKernelContext* context,
+ const ConstMatrixMap& input,
+ MatrixMap* output) = 0;
+
+ bool SupportsBatchOperation() final { return SupportsBatchOperationT; }
+
+ // A concrete implementation of LinearAlgebraOpBase::ComputeMatrix().
+ void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
+ const Tensor& in, const TensorShape& input_matrix_shape,
+ Tensor* out, const TensorShape& output_matrix_shape) final;
+};
+
+// Declare that LinearAlgebraOp is explicitly instantiated in
+// linalg_ops_common.cc for float and double.
+extern template class LinearAlgebraOp<float, false>;
+extern template class LinearAlgebraOp<float, true>;
+extern template class LinearAlgebraOp<double, false>;
+extern template class LinearAlgebraOp<double, true>;
+
+} // namespace tensorflow
+
+#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
+
+#endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
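To illustrate the extension point, a hypothetical op that returns the trace of each input matrix could subclass LinearAlgebraOp as sketched below. The class name, op name, and cost estimate are invented for this sketch; only the overridden signatures come from the header above.

    #include "tensorflow/core/kernels/linalg_ops_common.h"

    namespace tensorflow {

    template <typename Scalar, bool SupportsBatchOperationT>
    class TraceOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
     public:
      typedef LinearAlgebraOp<Scalar, SupportsBatchOperationT> Base;
      typedef typename Base::ConstMatrixMap ConstMatrixMap;
      typedef typename Base::MatrixMap MatrixMap;

      explicit TraceOp(OpKernelConstruction* context) : Base(context) {}

      TensorShape GetOutputMatrixShape(
          const TensorShape& input_matrix_shape) override {
        return TensorShape({});  // One scalar per input matrix.
      }

      int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override {
        // Roughly one addition per diagonal entry.
        return input_matrix_shape.dim_size(0);
      }

      void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
                         MatrixMap* output) override {
        (*output)(0, 0) = input.trace();
      }
    };

    // Registration would use the macro above, e.g.:
    // REGISTER_LINALG_OP("Trace", (TraceOp<float, false>), float);

    }  // namespace tensorflow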
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc
new file mode 100644
index 0000000000..f490f5ddd3
--- /dev/null
+++ b/tensorflow/core/kernels/listdiff_op.cc
@@ -0,0 +1,75 @@
+#include <unordered_set>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+template <typename T>
+class ListDiffOp : public OpKernel {
+ public:
+ explicit ListDiffOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt, DT_INT32}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& x = context->input(0);
+ const Tensor& y = context->input(1);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(x.shape()),
+ errors::InvalidArgument("x should be a 1D vector."));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()),
+ errors::InvalidArgument("y should be a 1D vector."));
+
+ std::unordered_set<T> y_set;
+ const auto Ty = y.vec<T>();
+ const int y_size = Ty.size();
+ y_set.reserve(y_size);
+ for (int i = 0; i < y_size; ++i) {
+ y_set.insert(Ty(i));
+ }
+
+ // Compute the size of the output.
+ const auto Tx = x.vec<T>();
+ const int x_size = Tx.size();
+
+ int out_size = 0;
+ for (int i = 0; i < x_size; ++i) {
+ if (y_set.count(Tx(i)) == 0) {
+ ++out_size;
+ }
+ }
+
+ // Allocate and populate outputs.
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, {out_size}, &out));
+ auto Tout = out->vec<T>();
+
+ Tensor* indices = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
+ auto Tindices = indices->vec<int32>();
+
+ for (int i = 0, p = 0; i < x_size; ++i) {
+ if (y_set.count(Tx(i)) == 0) {
+ Tout(p) = Tx(i);
+ Tindices(p) = i;
+ p++;
+ }
+ }
+ }
+};
+
+#define REGISTER_LISTDIFF(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ListDiff").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ListDiffOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_LISTDIFF);
+#undef REGISTER_LISTDIFF
+
+} // namespace tensorflow
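Put differently, ListDiff keeps every element of x that never appears in y, preserving x's order, and also reports where each kept element sat in x. The same semantics on plain vectors, as an illustrative standalone sketch:

    #include <unordered_set>
    #include <utility>
    #include <vector>

    // Returns the kept values and their original indices in x.
    std::pair<std::vector<int>, std::vector<int>> ListDiff(
        const std::vector<int>& x, const std::vector<int>& y) {
      std::unordered_set<int> y_set(y.begin(), y.end());
      std::vector<int> out, indices;
      for (int i = 0; i < static_cast<int>(x.size()); ++i) {
        if (y_set.count(x[i]) == 0) {
          out.push_back(x[i]);
          indices.push_back(i);
        }
      }
      return {out, indices};
    }

    // Example: x = {1, 2, 3, 4, 5}, y = {1, 3, 5}
    // -> out = {2, 4}, indices = {1, 3}.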
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
new file mode 100644
index 0000000000..ec84145f75
--- /dev/null
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -0,0 +1,77 @@
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class AssertOp : public OpKernel {
+ public:
+ explicit AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& cond = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(cond.shape()),
+ errors::InvalidArgument("In[0] should be a scalar: ",
+ cond.shape().ShortDebugString()));
+
+ if (cond.scalar<bool>()()) {
+ return;
+ }
+ string msg = "assertion failed: ";
+ for (int i = 1; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
+ "]");
+ if (i < ctx->num_inputs() - 1) strings::StrAppend(&msg, " ");
+ }
+ ctx->SetStatus(errors::InvalidArgument(msg));
+ }
+
+ private:
+ int32 summarize_ = 0;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Assert").Device(DEVICE_CPU), AssertOp);
+
+class PrintOp : public OpKernel {
+ public:
+ explicit PrintOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), call_counter_(0) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &message_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("first_n", &first_n_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (IsRefType(ctx->input_dtype(0))) {
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ ctx->set_output(0, ctx->input(0));
+ }
+ if (first_n_ >= 0) {
+ mutex_lock l(mu_);
+ if (call_counter_ >= first_n_) return;
+ call_counter_++;
+ }
+ string msg;
+ strings::StrAppend(&msg, message_);
+ for (int i = 1; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, "[", ctx->input(i).SummarizeValue(summarize_),
+ "]");
+ }
+ LOG(INFO) << msg;
+ }
+
+ private:
+ mutex mu_;
+ int64 call_counter_ GUARDED_BY(mu_) = 0;
+ int64 first_n_ = 0;
+ int32 summarize_ = 0;
+ string message_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
new file mode 100644
index 0000000000..a7af6eb303
--- /dev/null
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -0,0 +1,87 @@
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class PrintingGraphTest : public OpsTestBase {
+ protected:
+ Status Init(DataType input_type1, DataType input_type2, string msg = "",
+ int first_n = -1, int summarize = 3) {
+ RequireDefaultOps();
+ TF_CHECK_OK(NodeDefBuilder("op", "Print")
+ .Input(FakeInput(input_type1))
+ .Input(FakeInput(2, input_type2))
+ .Attr("message", msg)
+ .Attr("first_n", first_n)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingGraphTest, Int32Success_6) {
+ ASSERT_OK(Init(DT_INT32, DT_INT32));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(PrintingGraphTest, Int32Success_Summarize6) {
+ ASSERT_OK(Init(DT_INT32, DT_INT32, "", -1, 6));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(PrintingGraphTest, StringSuccess) {
+ ASSERT_OK(Init(DT_INT32, DT_STRING));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<string>(TensorShape({}), {"foo"});
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(PrintingGraphTest, MsgSuccess) {
+ ASSERT_OK(Init(DT_INT32, DT_STRING, "Message: "));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<string>(TensorShape({}), {"foo"});
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+TEST_F(PrintingGraphTest, FirstNSuccess) {
+ ASSERT_OK(Init(DT_INT32, DT_STRING, "", 3));
+ AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<string>(TensorShape({}), {"foo"});
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ // run 4 times but we only print 3 as intended
+ for (int i = 0; i < 4; i++) ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+ test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
new file mode 100644
index 0000000000..9781bcfa59
--- /dev/null
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -0,0 +1,116 @@
+#define EIGEN_USE_THREADS
+
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/initializable_lookup_table.h"
+#include "tensorflow/core/kernels/lookup_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Iterator to initialize tables given 'keys' and 'values' tensors.
+//
+// Both tensors are yielded in a single iteration; the iterator does not step
+// over individual elements, since lookup table insertions can process whole
+// batches.
+class KeyValueTensorIterator
+ : public InitializableLookupTable::InitTableIterator {
+ public:
+ // keys and values are not owned by the iterator.
+ explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
+ : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
+ TensorShape key_shape = keys_->shape();
+ if (!key_shape.IsSameSize(values_->shape())) {
+ valid_ = false;
+ status_ = errors::InvalidArgument(
+ "keys and values should have the same dimension.",
+ key_shape.DebugString(), " vs ", values_->shape().DebugString());
+ }
+ if (key_shape.num_elements() == 0) {
+ valid_ = false;
+ status_ =
+ errors::InvalidArgument("keys and values cannot be empty tensors.");
+ }
+ }
+
+ bool Valid() const override { return valid_; }
+
+ void Next() override {
+ valid_ = false;
+ status_ = errors::OutOfRange("No more data.");
+ }
+
+ const Tensor& keys() const override { return *keys_; }
+
+ const Tensor& values() const override { return *values_; }
+
+ Status status() const override { return status_; }
+
+ int64 total_size() const {
+ return keys_ == nullptr ? -1 : keys_->NumElements();
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
+
+ const Tensor* keys_; // Doesn't own it.
+ const Tensor* values_; // Doesn't own it.
+ bool valid_; // true if the iterator points to an existing range.
+ Status status_;
+};
+
+} // namespace lookup
+
+// Kernel to initialize a lookup table from key and value tensors.
+// After this operation, the table becomes read-only.
+class InitializeTableOp : public OpKernel {
+ public:
+ explicit InitializeTableOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ lookup::InitializableLookupTable* table;
+ OP_REQUIRES_OK(ctx,
+ GetInitializableLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ table->value_dtype()};
+ DataTypeVector expected_outputs = {};
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
+
+ const Tensor& keys = ctx->input(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(keys.shape()),
+ errors::InvalidArgument("Keys must be a vector, but received ",
+ keys.shape().DebugString()));
+
+ const Tensor& values = ctx->input(2);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(values.shape()),
+ errors::InvalidArgument("Values must be a vector, but received ",
+ values.shape().DebugString()));
+
+ OP_REQUIRES(ctx, keys.NumElements() == values.NumElements(),
+ errors::InvalidArgument(
+ "Keys and values must have the same size ",
+ keys.NumElements(), " vs ", values.NumElements()));
+
+ lookup::KeyValueTensorIterator iter(&keys, &values);
+ OP_REQUIRES_OK(ctx, table->Initialize(iter));
+ }
+
+ private:
+ mutex mu_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
+ InitializeTableOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
new file mode 100644
index 0000000000..2bab4df94f
--- /dev/null
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -0,0 +1,166 @@
+#include "tensorflow/core/kernels/lookup_table_op.h"
+#define EIGEN_USE_THREADS
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/initializable_lookup_table.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Lookup table that wraps an unordered_map, where the key and value data type
+// is specified.
+//
+// This table is recommended for arbitrary key values, as it makes no
+// assumptions about their distribution.
+//
+// For look up, the table is required to be initialized (allocated
+// and populated). Once the table is marked as initialized it becomes read-only.
+//
+// Sample use case:
+//
+// HashTable<int64, int64> table; // int64 -> int64.
+// table.Prepare(10); // Prepare the underlying data structure; the number of
+// // elements is required by the interface but not used.
+// // Populate the table, elements could be added in one or multiple calls.
+// table.Insert(key_tensor, value_tensor); // Populate the table.
+// ...
+// table.set_is_initialized();
+//
+// table.Find(in_t, &out_t, default_t)
+//
+template <class K, class V>
+class HashTable : public InitializableLookupTable {
+ public:
+ size_t size() const override { return table_ ? table_->size() : 0; }
+
+ DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
+
+ DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
+
+ protected:
+ Status DoPrepare(size_t unused) override {
+ if (is_initialized_) {
+ return errors::Aborted("HashTable already initialized.");
+ }
+ if (!table_) {
+ table_ = std::unique_ptr<std::unordered_map<K, V>>(
+ new std::unordered_map<K, V>());
+ }
+ return Status::OK();
+ };
+
+ Status DoInsert(const Tensor& keys, const Tensor& values) override {
+ if (!table_) {
+ return errors::FailedPrecondition("HashTable is not prepared.");
+ }
+
+ const auto key_values = keys.flat<K>();
+ const auto value_values = values.flat<V>();
+ for (size_t i = 0; i < key_values.size(); ++i) {
+ const K& key = key_values(i);
+ const V& value = value_values(i);
+ const V& previous_value = gtl::LookupOrInsert(table_.get(), key, value);
+ if (previous_value != value) {
+ return errors::FailedPrecondition(
+ "HashTable has different value for same key. Key ", key, " has ",
+ previous_value, " and trying to add value ", value);
+ }
+ }
+ return Status::OK();
+ }
+
+ Status DoFind(const Tensor& key, Tensor* value,
+ const Tensor& default_value) override {
+ const V default_val = default_value.flat<V>()(0);
+ const auto key_values = key.flat<K>();
+ auto value_values = value->flat<V>();
+
+ for (size_t i = 0; i < key_values.size(); ++i) {
+ value_values(i) =
+ gtl::FindWithDefault(*table_, key_values(i), default_val);
+ }
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<std::unordered_map<K, V>> table_;
+};
+
+} // namespace lookup
+
+// Table lookup op. Perform the lookup operation on the given table.
+class LookupTableFindOp : public OpKernel {
+ public:
+ explicit LookupTableFindOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ lookup::LookupInterface* table;
+ OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
+ table->value_dtype()};
+ DataTypeVector expected_outputs = {table->value_dtype()};
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
+
+ const Tensor& input = ctx->input(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("Input must be a vector, not ",
+ input.shape().DebugString()));
+
+ const Tensor& default_value = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(default_value.shape()),
+ errors::InvalidArgument("Default value must be a scalar, not ",
+ default_value.shape().DebugString()));
+
+ Tensor* out;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output("output_values", input.shape(), &out));
+
+ OP_REQUIRES_OK(ctx, table->Find(input, out, default_value));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
+ LookupTableFindOp);
+
+// Op that returns the size of the given table.
+class LookupTableSizeOp : public OpKernel {
+ public:
+ explicit LookupTableSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ lookup::LookupInterface* table;
+ OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
+ core::ScopedUnref unref_me(table);
+
+ Tensor* out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("size", TensorShape({}), &out));
+ out->flat<int64>().setConstant(table->size());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
+ LookupTableSizeOp);
+
+// Register the HashTable op with the currently supported key and value types.
+#define REGISTER_KERNEL(key_dtype, value_dtype) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("HashTable") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<key_dtype>("key_dtype") \
+ .TypeConstraint<value_dtype>("value_dtype"), \
+ LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
+ value_dtype>)
+
+REGISTER_KERNEL(string, int64);
+REGISTER_KERNEL(int64, string);
+
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
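The interesting behavior above lives in the two gtl helpers: DoInsert rejects a key that is re-inserted with a different value, and DoFind falls back to the scalar default for missing keys. A plain-STL sketch of the equivalent behavior (illustrative names, std::unordered_map only):

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    // Mirrors DoInsert's conflict check: inserting an existing key with the
    // same value is a no-op, but a different value is reported as a conflict.
    bool InsertOrCheck(std::unordered_map<std::string, int64_t>* table,
                       const std::string& key, int64_t value) {
      auto result = table->insert({key, value});
      return result.second || result.first->second == value;
    }

    // Mirrors DoFind: missing keys map to the provided default value.
    int64_t FindWithDefault(const std::unordered_map<std::string, int64_t>& table,
                            const std::string& key, int64_t default_value) {
      auto it = table.find(key);
      return it == table.end() ? default_value : it->second;
    }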
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
new file mode 100644
index 0000000000..dc53ce33a6
--- /dev/null
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -0,0 +1,80 @@
+#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
+#define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
+
+#include "tensorflow/core/framework/lookup_interface.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/lookup_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Lookup table op that supports different table implementations specified by
+// the 'Container' template. Container must be derived from LookupInterface. The
+// key and value are of the templated types "key_dtype" and "value_dtype",
+// respectively.
+template <class Container, class key_dtype, class value_dtype>
+class LookupTableOp : public OpKernel {
+ public:
+ // ctx is not owned by this class.
+ explicit LookupTableOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), table_handle_set_(false) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_persistent(tensorflow::DT_STRING,
+ tensorflow::TensorShape({2}),
+ &table_handle_, nullptr));
+ }
+
+ // ctx is not owned by this function.
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ if (!table_handle_set_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
+ auto creator = [this](lookup::LookupInterface** ret) {
+ *ret = new Container();
+ return Status::OK();
+ };
+
+ lookup::LookupInterface* table = nullptr;
+ OP_REQUIRES_OK(
+ ctx, cinfo_.resource_manager()
+ ->template LookupOrCreate<lookup::LookupInterface>(
+ cinfo_.container(), cinfo_.name(), &table, creator));
+ core::ScopedUnref unref_me(table);
+
+ OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
+ *table, DataTypeToEnum<key_dtype>::v(),
+ DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
+
+ auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ table_handle_set_ = true;
+ }
+ ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
+ }
+
+ ~LookupTableOp() override {
+ // If the table object was not shared, delete it.
+ if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(
+ cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+ }
+
+ private:
+ mutex mu_;
+ PersistentTensor table_handle_ GUARDED_BY(mu_);
+ bool table_handle_set_ GUARDED_BY(mu_);
+ ContainerInfo cinfo_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LookupTableOp);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
new file mode 100644
index 0000000000..634c11e4a5
--- /dev/null
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -0,0 +1,72 @@
+#include "tensorflow/core/kernels/lookup_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+namespace lookup {
+namespace {
+
+Status GetTableHandle(const string& input_name, OpKernelContext* ctx,
+ string* container, string* table_handle) {
+ {
+ mutex* mu;
+ TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu));
+ mutex_lock l(*mu);
+ Tensor tensor;
+ TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true));
+ if (tensor.NumElements() != 2) {
+ return errors::InvalidArgument(
+ "Lookup table handle must be scalar, but had shape: ",
+ tensor.shape().DebugString());
+ }
+ auto h = tensor.flat<string>();
+ *container = h(0);
+ *table_handle = h(1);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
+ LookupInterface** table) {
+ string container;
+ string table_handle;
+ TF_RETURN_IF_ERROR(
+ GetTableHandle(input_name, ctx, &container, &table_handle));
+ return ctx->resource_manager()->Lookup(container, table_handle, table);
+}
+
+Status GetInitializableLookupTable(const string& input_name,
+ OpKernelContext* ctx,
+ InitializableLookupTable** table) {
+ string container;
+ string table_handle;
+ TF_RETURN_IF_ERROR(
+ GetTableHandle(input_name, ctx, &container, &table_handle));
+ LookupInterface* lookup_table;
+ TF_RETURN_IF_ERROR(
+ ctx->resource_manager()->Lookup(container, table_handle, &lookup_table));
+ *table = dynamic_cast<InitializableLookupTable*>(lookup_table);
+ if (*table == nullptr) {
+ lookup_table->Unref();
+ return errors::InvalidArgument("Table ", container, " ", table_handle,
+ " is not initializable");
+ }
+ return Status::OK();
+}
+
+Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
+ DataType value_dtype, const string& table_name) {
+ if (table.key_dtype() != key_dtype || table.value_dtype() != value_dtype) {
+ return errors::InvalidArgument(
+ "Conflicting key/value dtypes ", key_dtype, "->", value_dtype, " with ",
+ table.key_dtype(), "-", table.value_dtype(), " for table ", table_name);
+ }
+ return Status::OK();
+}
+
+} // namespace lookup
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h
new file mode 100644
index 0000000000..991a757edd
--- /dev/null
+++ b/tensorflow/core/kernels/lookup_util.h
@@ -0,0 +1,31 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_
+
+#include "tensorflow/core/framework/lookup_interface.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/initializable_lookup_table.h"
+
+namespace tensorflow {
+namespace lookup {
+
+// Gets the LookupTable stored in ctx->resource_manager() under the handle
+// passed via the input named input_name. Returns a non-OK status if the
+// table doesn't exist.
+Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
+ LookupInterface** table);
+
+// Gets the InitializableLookupTable stored in ctx->resource_manager() under
+// the handle passed via the input named input_name. Returns a non-OK status
+// if the table doesn't exist.
+Status GetInitializableLookupTable(const string& input_name,
+ OpKernelContext* ctx,
+ InitializableLookupTable** table);
+
+// Verifies that the given key_dtype and value_dtype match the table's data
+// types.
+Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype,
+ DataType value_dtype, const string& table_name);
+} // namespace lookup
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
new file mode 100644
index 0000000000..e5abf5906f
--- /dev/null
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -0,0 +1,228 @@
+// LRN = Local Response Normalization
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#ifndef __ANDROID__
+#include "tensorflow/core/util/work_sharder.h"
+#endif
+
+namespace tensorflow {
+
+// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
+// depth_radius + 1) around the diagonal.
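+// For example, with depth = 5 and depth_radius = 1 the matrix is
+//   1 1 0 0 0
+//   1 1 1 0 0
+//   0 1 1 1 0
+//   0 0 1 1 1
+//   0 0 0 1 1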
+static void GetBandMatrix(int depth, int64 depth_radius,
+ Eigen::Tensor<float, 2, Eigen::RowMajor>* result) {
+ result->setZero();
+ for (int row = 0; row < depth; ++row) {
+ const int begin = std::max<int>(0, row - depth_radius);
+ const int end = std::min<int64>(depth, row + depth_radius + 1);
+ Eigen::DSizes<ptrdiff_t, 2> start(row, begin);
+ Eigen::DSizes<ptrdiff_t, 2> sizes(1, end - begin);
+ result->slice(start, sizes).setConstant(1.0f);
+ }
+}
+
+class LRNOp : public OpKernel {
+ public:
+ explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius_));
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& in = context->input(0);
+ OP_REQUIRES(context, in.dims() == 4,
+ errors::InvalidArgument("in must be 4-dimensional"));
+ const int64 batch = in.dim_size(0);
+ const int64 rows = in.dim_size(1);
+ const int64 cols = in.dim_size(2);
+ const int64 depth = in.dim_size(3);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({batch, rows, cols, depth}), &output));
+
+#ifdef __ANDROID__
+ MognetLRN(in, batch, rows, cols, depth, output);
+#else
+ const int nodes = cols * rows;
+ auto in_shaped = in.shaped<float, 2>({nodes * batch, depth});
+
+ // Contracting the squared input with the band matrix sums, for each
+ // position, the window of size (2 * depth_radius + 1) along the depth.
+ Eigen::Tensor<float, 2, Eigen::RowMajor> multiplier(depth, depth);
+ GetBandMatrix(depth, depth_radius_, &multiplier);
+
+ auto out_shaped = output->shaped<float, 2>({nodes * batch, depth});
+ Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+ /// TODO(keveman): Optimize for beta in {0, 1, 0.5}
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped /
+ in_shaped.square()
+ .contract(multiplier, dims)
+ .unaryExpr([this](float x) { return bias_ + alpha_ * x; })
+ .pow(beta_);
+#endif
+ }
+
+ private:
+ typedef Eigen::Tensor<float, 1, Eigen::RowMajor>::DimensionPair DimPair;
+
+ void MognetLRN(const Tensor& in, const int batch, const int rows,
+ const int cols, const int depth, Tensor* out) {
+ Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>>
+ data_in(in.flat<float>().data(), depth, batch * rows * cols);
+
+ Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>> data_out(
+ out->flat<float>().data(), depth, batch * rows * cols);
+
+ const int double_depth_radius = depth_radius_ * 2;
+ Eigen::VectorXf padded_square(data_in.rows() + double_depth_radius);
+ padded_square.setZero();
+ for (int r = 0; r < data_in.cols(); ++r) {
+ // Do local response normalization for data_in(:, r)
+ // First, compute the squares and store them in a buffer for repeated use.
+ padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
+ data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
+ // Then, compute the scale and write it to data_out.
+ float accumulated_scale = 0;
+ for (int i = 0; i < double_depth_radius; ++i) {
+ accumulated_scale += padded_square(i);
+ }
+ for (int i = 0; i < data_in.rows(); ++i) {
+ accumulated_scale += padded_square(i + double_depth_radius);
+ data_out(i, r) = bias_ + accumulated_scale;
+ accumulated_scale -= padded_square(i);
+ }
+ }
+
+ // Special-case beta values for which pow can be replaced by cheaper operations.
+ if (beta_ == 1) {
+ data_out.array() = data_in.array() * data_out.array().inverse();
+ } else if (beta_ == 0.5) {
+ data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
+ } else {
+ data_out.array() = data_in.array() * data_out.array().pow(-beta_);
+ }
+ }
+
+ int64 depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("LRN").Device(DEVICE_CPU), LRNOp);
+
+#ifndef __ANDROID__
+
+class LRNGradOp : public OpKernel {
+ public:
+ explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius_));
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& in_grads = context->input(0);
+ const Tensor& in_image = context->input(1);
+ const Tensor& out_image = context->input(2);
+
+ OP_REQUIRES(context, in_grads.dims() == 4 && in_image.dims() == 4,
+ errors::InvalidArgument("inputs must be 4-dimensional"));
+ const int64 batch = in_grads.dim_size(0);
+ const int64 rows = in_grads.dim_size(1);
+ const int64 cols = in_grads.dim_size(2);
+ const int64 depth = in_grads.dim_size(3);
+ OP_REQUIRES(
+ context,
+ in_image.dim_size(0) == batch && in_image.dim_size(1) == rows &&
+ in_image.dim_size(2) == cols && in_image.dim_size(3) == depth &&
+ out_image.dim_size(0) == batch && out_image.dim_size(1) == rows &&
+ out_image.dim_size(2) == cols && out_image.dim_size(3) == depth,
+ errors::InvalidArgument(
+ "input_grads, input_image, and out_image should have the same "
+ "shape"));
+ const auto nodes = cols * rows;
+ auto grads_shaped = in_grads.shaped<float, 2>({nodes * batch, depth});
+ auto in_shaped = in_image.shaped<float, 2>({nodes * batch, depth});
+ auto activations = out_image.shaped<float, 2>({nodes * batch, depth});
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({batch, rows, cols, depth}), &output));
+ auto out_shaped = output->shaped<float, 2>({nodes * batch, depth});
+ out_shaped.setZero();
+
+ auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
+ depth](int64 begin, int64 end) {
+ for (int64 i = begin; i < end; ++i) {
+ for (int64 j = 0; j < depth; ++j) {
+ // Let y be the LRN activations and x be the inputs along the depth
+ // dimension. (LRN operates independently along rows, cols, and
+ // batch).
+ // We have
+ // yi = xi / (bias + alpha(sum_j_{i - depth_radius}^{i + depth_radius}
+ // x_j^2))^beta
+ //
+ // Let N = (bias + alpha(sum_j_{i - depth_radius}^{i + depth_radius}
+ // x_j^2))
+ // dy_i/dx_i = (N^beta - xi * beta*N^(beta-1)*2*alpha*xi)/N^(2*beta)
+ // dy_i/dx_j = ( - xi * beta*N^(beta-1)*2*alpha*xj)/N^(2*beta)
+ //
+ // NOTE(keveman): We can compute N by doing (yi/xi) ^ (1/beta).
+ // However, this is numerically unstable for small values of xi. We
+ // compute N explicitly here to avoid that.
+
+ int64 depth_begin = std::max<int64>(0, j - depth_radius_);
+ int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
+
+ float norm = 0.0f;
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ norm += in_shaped(i, k) * in_shaped(i, k);
+ }
+ norm = alpha_ * norm + bias_;
+ DCHECK_GT(norm, 1e-6);
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ float dyi = -2.0f * alpha_ * beta_ * in_shaped(i, k) *
+ activations(i, j) / norm;
+ if (k == j) {
+ dyi += std::pow(norm, -beta_);
+ }
+ dyi *= grads_shaped(i, j);
+ const_cast<TTypes<float, 2>::Tensor&>(out_shaped)(i, k) += dyi;
+ }
+ }
+ }
+ };
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
+ depth * depth, shard);
+ }
+
+ private:
+ typedef Eigen::Tensor<float, 1, Eigen::RowMajor>::DimensionPair DimPair;
+
+ int64 depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("LRNGrad").Device(DEVICE_CPU), LRNGradOp);
+
+#endif // __ANDROID__
+
+} // namespace tensorflow
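The LRN formula that lrn_op.cc implements (y_i = x_i divided by (bias + alpha * sum of x_j^2 over the window of 2*depth_radius+1 entries)^beta) can be checked against a minimal standalone sketch. The helper below is illustrative only, not part of this patch; it mirrors the per-depth-vector math and not the 4-D tensor plumbing or the band-matrix contraction.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference LRN over a single depth vector: for each index i, normalize by
// (bias + alpha * sum of squares in the window [i - radius, i + radius]).
std::vector<float> LrnReference(const std::vector<float>& x, int depth_radius,
                                float bias, float alpha, float beta) {
  const int depth = static_cast<int>(x.size());
  std::vector<float> y(depth);
  for (int i = 0; i < depth; ++i) {
    const int begin = std::max(0, i - depth_radius);
    const int end = std::min(depth, i + depth_radius + 1);
    float sum_sq = 0.0f;
    for (int j = begin; j < end; ++j) sum_sq += x[j] * x[j];
    y[i] = x[i] / std::pow(bias + alpha * sum_sq, beta);
  }
  return y;
}

int main() {
  const std::vector<float> x = {1, 2, 3, 4};
  for (float v : LrnReference(x, /*depth_radius=*/2, /*bias=*/1.0f,
                              /*alpha=*/0.1f, /*beta=*/2.0f)) {
    std::printf("%f\n", v);
  }
  return 0;
}
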
diff --git a/tensorflow/core/kernels/lrn_op_test.cc b/tensorflow/core/kernels/lrn_op_test.cc
new file mode 100644
index 0000000000..4c338b6cb3
--- /dev/null
+++ b/tensorflow/core/kernels/lrn_op_test.cc
@@ -0,0 +1,185 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+static const float tol_ = 1e-4;
+
+class LRNFloatTest : public OpsTestBase {
+ protected:
+ LRNFloatTest() : philox_(123, 17), rand_(&philox_) { RequireDefaultOps(); }
+
+ int GetIntAttr(const string& name) {
+ int value;
+ TF_CHECK_OK(GetNodeAttr(*node_def(), name, &value));
+ return value;
+ }
+
+ float GetFloatAttr(const string& name) {
+ float value;
+ TF_CHECK_OK(GetNodeAttr(*node_def(), name, &value));
+ return value;
+ }
+
+ bool Compare() {
+ const auto& input = GetInput(0);
+ const int64 batch_size = input.dim_size(0);
+ const int64 rows = input.dim_size(1);
+ const int64 cols = input.dim_size(2);
+ const int64 depth = input.dim_size(3);
+ const int64 rest = cols * rows * batch_size;
+
+ const int64 depth_radius = GetIntAttr("depth_radius");
+ const float bias = GetFloatAttr("bias");
+ const float alpha = GetFloatAttr("alpha");
+ const float beta = GetFloatAttr("beta");
+
+ Eigen::Tensor<float, 4, Eigen::RowMajor> expected(batch_size, rows, cols,
+ depth);
+ auto out = expected.reshape(Eigen::DSizes<int64, 2>{rest, depth});
+ auto in = input.shaped<float, 2>({rest, depth});
+
+ for (int64 i = 0; i < rest; ++i) {
+ Eigen::Tensor<float, 1, Eigen::RowMajor> out_col(depth);
+ for (int64 d = 0; d < depth; ++d) {
+ float denom = 0.0f;
+ for (int64 r = std::max(0ll, d - depth_radius);
+ r < std::min(depth, d + depth_radius + 1); ++r) {
+ denom += in(i, r) * in(i, r);
+ }
+ denom = std::pow(denom * alpha + bias, beta);
+ out_col(d) = in(i, d) / denom;
+ }
+ out.chip<0>(i) = out_col;
+ }
+ auto actual = GetOutput(0)->tensor<float, 4>();
+ Eigen::Tensor<float, 0, Eigen::RowMajor> sum =
+ ((expected - actual).abs() > actual.constant(tol_))
+ .select(actual.constant(1), actual.constant(0))
+ .sum();
+ return sum() == 0;
+ }
+
+ random::PhiloxRandom philox_;
+ random::SimplePhilox rand_;
+};
+
+TEST_F(LRNFloatTest, Depth96) {
+ ASSERT_OK(NodeDefBuilder("lrn_op", "LRN")
+ .Input(FakeInput())
+ .Attr("depth_radius", 5)
+ .Attr("bias", 1.0f)
+ .Attr("alpha", 0.1f)
+ .Attr("beta", 2.0f)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ AddInput<float>(TensorShape({1, 1, 1, 96}),
+ [this](int i) -> float { return i + 1; });
+ ASSERT_OK(RunOpKernel());
+ auto actual = GetOutput(0)->tensor<float, 4>();
+
+ // Output for Node 0 with Value 1:
+ // 1 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2))^2
+ EXPECT_NEAR(1. / (10.1 * 10.1), actual(0, 0, 0, 0), tol_);
+
+ // Output for Node 5 with Value 6:
+ // 6 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2 ... + 11^2))^2
+ EXPECT_NEAR(6. / (51.6 * 51.6), actual(0, 0, 0, 5), tol_);
+
+ // Output for Node 63 with value 64:
+ // 64 / (1 + 0.1*(59^2 + 60^2 + 61^2 + 62^2 + 63^2 + 64^2))^2
+ EXPECT_NEAR(64. / (2272.1 * 2272.1), actual(0, 0, 0, 63), tol_);
+
+ // Output for Node 64 with value 65:
+ // 65 / (1 + 0.1*(65^2 + 66^2 + 67^2 + 68^2 + 69^2 + 70^2))^2
+ EXPECT_NEAR(65. / (2736.5 * 2736.5), actual(0, 0, 0, 64), tol_);
+
+ // Output for Node 95 with value 96:
+ // 96 / (1 + 0.1*(91^2 + 92^2 + 93^2 + 94^2 + 95^2 + 96^2))^2
+ EXPECT_NEAR(96. / (5248.1 * 5248.1), actual(0, 0, 0, 95), tol_);
+ EXPECT_TRUE(Compare());
+}
+
+TEST_F(LRNFloatTest, Depth16) {
+ ASSERT_OK(NodeDefBuilder("lrn_op", "LRN")
+ .Input(FakeInput())
+ .Attr("depth_radius", 5)
+ .Attr("bias", 1.0f)
+ .Attr("alpha", 0.1f)
+ .Attr("beta", 2.0f)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ AddInput<float>(TensorShape({1, 1, 1, 16}),
+ [this](int i) -> float { return i + 1; });
+ ASSERT_OK(RunOpKernel());
+ auto actual = GetOutput(0)->tensor<float, 4>();
+
+ // Output for Node 0 with Value 1:
+ // 1 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2))^2
+ EXPECT_NEAR(1. / (10.1 * 10.1), actual(0, 0, 0, 0), tol_);
+
+ // Output for Node 5 with Value 6:
+ // 6 / (1 + 0.1*(1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2 ... + 11^2))^2
+ EXPECT_NEAR(6. / (51.6 * 51.6), actual(0, 0, 0, 5), tol_);
+
+ // Output for Node 15 with value 16:
+ // 16 / (1 + 0.1*(11^2 + 12^2 + 13^2 + 14^2 + 15^2 + 16^2))^2
+ EXPECT_NEAR(16. / (112.1 * 112.1), actual(0, 0, 0, 15), tol_);
+ EXPECT_TRUE(Compare());
+}
+
+static double RndGaussian(random::SimplePhilox* rnd) {
+ // Box-Muller transformation.
+ // See, for example, http://www.taygeta.com/random/gaussian.html
+ double x1, x2;
+ double r;
+ do {
+ x1 = 2 * rnd->RandDouble() - 1;
+ x2 = 2 * rnd->RandDouble() - 1;
+ r = x1 * x1 + x2 * x2;
+ } while (r == 0 || r >= 1.0);
+ double w = sqrt(-2.0 * log(r) / r);
+ return x1 * w;
+}
+
+#define TCASE(NAME, DEPTH, BATCH, DEPTH_RADIUS, BIAS, ALPHA, BETA) \
+ TEST_F(LRNFloatTest, NAME) { \
+ ASSERT_OK(NodeDefBuilder("lrn_op", "LRN") \
+ .Input(FakeInput()) \
+ .Attr("depth_radius", (DEPTH_RADIUS)) \
+ .Attr("bias", (BIAS)) \
+ .Attr("alpha", ((ALPHA) / 10)) \
+ .Attr("beta", (BETA)) \
+ .Finalize(node_def())); \
+ ASSERT_OK(InitOp()); \
+ AddInput<float>(TensorShape({BATCH, 1, 1, DEPTH}), \
+ [this](int i) -> float { return RndGaussian(&rand_); }); \
+ ASSERT_OK(RunOpKernel()); \
+ EXPECT_TRUE(Compare()); \
+ }
+
+// clang-format off
+// DEPTH BATCH DEPTH_RADIUS BIAS ALPHA BETA
+TCASE(T0, 4, 2, 2, 1.0f, 1.0f, 2.0f)
+TCASE(T1, 16, 1, 5, 1.0f, 1.0f, 2.0f)
+TCASE(T2, 16, 32, 2, 1.0f, 2.0f, 1.0f)
+TCASE(T3, 128, 4, 3, 2.0f, 1.0f, 1.0f)
+// clang-format on
+
+#undef TCASE
+} // namespace tensorflow
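RndGaussian above is the polar (Marsaglia) form of the Box-Muller transform driven by SimplePhilox. Below is a self-contained sketch of the same transform over std::mt19937, illustrative only; the test itself stays on the Philox generator for reproducibility.

#include <cmath>
#include <cstdio>
#include <random>

// Polar Box-Muller: draw (x1, x2) uniformly in the unit disk and map the
// pair to a standard normal sample, mirroring RndGaussian above.
double PolarBoxMuller(std::mt19937* gen) {
  std::uniform_real_distribution<double> uniform(-1.0, 1.0);
  double x1, x2, r;
  do {
    x1 = uniform(*gen);
    x2 = uniform(*gen);
    r = x1 * x1 + x2 * x2;
  } while (r == 0.0 || r >= 1.0);
  return x1 * std::sqrt(-2.0 * std::log(r) / r);
}

int main() {
  std::mt19937 gen(123);
  for (int i = 0; i < 5; ++i) std::printf("%f\n", PolarBoxMuller(&gen));
  return 0;
}
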
diff --git a/tensorflow/core/kernels/matching_files_op.cc b/tensorflow/core/kernels/matching_files_op.cc
new file mode 100644
index 0000000000..08a4da5b41
--- /dev/null
+++ b/tensorflow/core/kernels/matching_files_op.cc
@@ -0,0 +1,42 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/match.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class MatchingFilesOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+ void Compute(OpKernelContext* context) override {
+ const Tensor* pattern;
+ OP_REQUIRES_OK(context, context->input("pattern", &pattern));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(pattern->shape()),
+ errors::InvalidArgument(
+ "Input pattern tensor must be scalar, but had shape: ",
+ pattern->shape().DebugString()));
+ std::vector<string> fnames;
+ OP_REQUIRES_OK(context,
+ io::GetMatchingFiles(context->env(),
+ pattern->scalar<string>()(), &fnames));
+ const int num_out = fnames.size();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "filenames", TensorShape({num_out}), &output));
+ auto output_vec = output->vec<string>();
+ for (int i = 0; i < num_out; ++i) {
+ output_vec(i) = fnames[i];
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MatchingFiles").Device(DEVICE_CPU),
+ MatchingFilesOp);
+
+} // namespace tensorflow
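The op's contract is simple: a scalar string pattern in, a 1-D string tensor of matching filenames out. The sketch below shows the same shape of API in plain C++; the suffix matcher is only a hypothetical stand-in for io::GetMatchingFiles, which performs real glob matching against the filesystem.

#include <cstdio>
#include <string>
#include <vector>

// Return the subset of `names` that end with `suffix` (stand-in matcher).
std::vector<std::string> MatchBySuffix(const std::vector<std::string>& names,
                                       const std::string& suffix) {
  std::vector<std::string> out;
  for (const std::string& n : names) {
    if (n.size() >= suffix.size() &&
        n.compare(n.size() - suffix.size(), suffix.size(), suffix) == 0) {
      out.push_back(n);
    }
  }
  return out;
}

int main() {
  const std::vector<std::string> files = {"a.txt", "b.csv", "c.txt"};
  for (const std::string& f : MatchBySuffix(files, ".txt"))
    std::printf("%s\n", f.c_str());  // prints a.txt and c.txt
  return 0;
}
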
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
new file mode 100644
index 0000000000..48bdba78b2
--- /dev/null
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -0,0 +1,214 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/matmul_op.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+
+namespace {
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+} // namespace
+
+#endif // GOOGLE_CUDA
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct LaunchMatMul;
+
+// On CPUs, we ignore USE_CUBLAS
+template <typename T>
+struct LaunchMatMulCPU {
+ static void launch(
+ OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
+ Tensor* out) {
+ functor::MatMulFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(),
+ out->matrix<T>(), a.matrix<T>(),
+ b.matrix<T>(), dim_pair);
+ }
+};
+
+template <typename T, bool USE_CUBLAS>
+struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};
+
+#if GOOGLE_CUDA
+
+template <typename T>
+struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
+ static void launch(
+ OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
+ Tensor* out) {
+ perftools::gputools::blas::Transpose trans[] = {
+ perftools::gputools::blas::Transpose::kNoTranspose,
+ perftools::gputools::blas::Transpose::kTranspose};
+ const uint64 m = a.dim_size(1 - dim_pair[0].first);
+ const uint64 k = a.dim_size(dim_pair[0].first);
+ const uint64 n = b.dim_size(1 - dim_pair[0].second);
+ bool transpose_a = dim_pair[0].first == 0;
+ bool transpose_b = dim_pair[0].second == 1;
+ auto blas_transpose_a = trans[transpose_a];
+ auto blas_transpose_b = trans[transpose_b];
+
+ auto* stream = ctx->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ auto a_ptr = AsDeviceMemory(a.template flat<T>().data());
+ auto b_ptr = AsDeviceMemory(b.template flat<T>().data());
+ auto c_ptr = AsDeviceMemory(out->template flat<T>().data());
+
+ // Cublas does
+ // C = A x B
+ // where A, B and C are assumed to be in column-major order.
+ // We want the output to be in row-major, so we can compute
+ // C' = B' x A' (' stands for transpose)
+ bool blas_launch_status =
+ stream->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k, 1.0f,
+ b_ptr, transpose_b ? k : n, a_ptr,
+ transpose_a ? m : k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "Blas SGEMM launch failed : a.shape=(", a.dim_size(0), ", ",
+ a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
+ "), m=", m, ", n=", n, ", k=", k));
+ }
+ }
+};
+
+template <typename T>
+struct LaunchMatMul<GPUDevice, T, false /* USE_CUBLAS */> {
+ static void launch(
+ OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
+ Tensor* out) {
+ functor::MatMulFunctor<GPUDevice, T>()(ctx->eigen_device<GPUDevice>(),
+ out->matrix<T>(), a.matrix<T>(),
+ b.matrix<T>(), dim_pair);
+ }
+};
+
+#endif // GOOGLE_CUDA
+
+template <typename Device, typename T, bool USE_CUBLAS>
+class MatMulOp : public OpKernel {
+ public:
+ explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& a = ctx->input(0);
+ const Tensor& b = ctx->input(1);
+
+ // Check that the dimensions of the two matrices are valid.
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
+ errors::InvalidArgument("In[0] is not a matrix"));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
+ errors::InvalidArgument("In[1] is not a matrix"));
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0].first = transpose_a_ ? 0 : 1;
+ dim_pair[0].second = transpose_b_ ? 1 : 0;
+
+ OP_REQUIRES(ctx,
+ a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
+ errors::InvalidArgument("Matrix size-compatible: In[0]: ",
+ a.shape().DebugString(), ", In[1]: ",
+ b.shape().DebugString()));
+ int a_dim_remaining = 1 - dim_pair[0].first;
+ int b_dim_remaining = 1 - dim_pair[0].second;
+ TensorShape out_shape(
+ {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
+
+ if (out->NumElements() == 0) {
+ // If a has shape [0, x] or b has shape [x, 0], the output shape
+ // is a 0-element matrix, so there is nothing to do.
+ return;
+ }
+
+ if (a.NumElements() == 0 || b.NumElements() == 0) {
+ // If a has shape [x, 0] and b has shape [0, y], the
+ // output shape is [x, y] where x and y are non-zero, so we fill
+ // the output with zeros.
+ functor::SetZeroFunctor<Device, T> f;
+ f(ctx->eigen_device<Device>(), out->flat<T>());
+ return;
+ }
+
+ LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
+ }
+
+ private:
+ bool transpose_a_;
+ bool transpose_b_;
+};
+
+namespace functor {
+
+// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
+template <typename T>
+struct MatMulFunctor<CPUDevice, T> {
+ void operator()(
+ const CPUDevice& d, typename MatMulTypes<T>::out_type out,
+ typename MatMulTypes<T>::in_type in0,
+ typename MatMulTypes<T>::in_type in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
+ }
+};
+
+} // end namespace functor
+
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
+ MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>)
+
+#define REGISTER_GPU(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
+ REGISTER_KERNEL_BUILDER(Name("MatMul") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .Label("cublas"), \
+ MatMulOp<GPUDevice, T, true /* cublas */>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T").Label("eigen"), \
+ MatMulOp<GPUDevice, T, false /* cublas */>)
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+REGISTER_CPU(int32);
+REGISTER_CPU(complex64);
+#if GOOGLE_CUDA
+REGISTER_GPU(float);
+// REGISTER_GPU(double);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
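The comment in LaunchMatMul<GPUDevice, T, true> relies on the fact that a row-major buffer reinterpreted as column-major is the transpose of the original matrix, so computing C' = B' x A' in column-major leaves C laid out row-major without any explicit transposes. Below is a minimal CPU sketch of that trick; the naive column-major GEMM is only a stand-in for the ThenBlasGemm call and is not TensorFlow code.

#include <cstdio>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n), all column-major,
// standing in for the BLAS call (no transposes, alpha = 1, beta = 0).
void GemmColMajor(int m, int n, int k, const std::vector<float>& A,
                  const std::vector<float>& B, std::vector<float>* C) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[i + p * m] * B[p + j * k];
      (*C)[i + j * m] = acc;
    }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C (2x2) = A * B.
  const int m = 2, k = 3, n = 2;
  const std::vector<float> A = {1, 2, 3, 4, 5, 6};     // 2x3, row-major
  const std::vector<float> B = {7, 8, 9, 10, 11, 12};  // 3x2, row-major
  std::vector<float> C(m * n);
  // A row-major m x k buffer is, bit for bit, a column-major k x m matrix
  // (i.e. A'), so computing C' = B' * A' in column-major leaves C row-major.
  GemmColMajor(n, m, k, B, A, &C);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) std::printf("%6.1f ", C[i * n + j]);
    std::printf("\n");
  }  // Expect: 58 64 / 139 154
  return 0;
}
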
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
new file mode 100644
index 0000000000..f75b0ded1b
--- /dev/null
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_KERNELS_MATMUL_OP_H_
+#define TENSORFLOW_KERNELS_MATMUL_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Helpers to define tensor<T> needed by MatMul op.
+template <typename T>
+struct MatMulTypes {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Aligned>
+ out_type;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Aligned> in_type;
+};
+
+template <typename Device, typename In0, typename In1, typename Out,
+ typename DimPair>
+void MatMul(const Device& d, Out out, In0 in0, In1 in1,
+ const DimPair& dim_pair) {
+ out.device(d) = in0.contract(in1, dim_pair);
+}
+
+template <typename Device, typename T>
+struct MatMulFunctor {
+ // Computes on device "d": out = in0 * in1, where * is matrix
+ // multiplication.
+ void operator()(
+ const Device& d, typename MatMulTypes<T>::out_type out,
+ typename MatMulTypes<T>::in_type in0,
+ typename MatMulTypes<T>::in_type in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair);
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
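MatMul above is a thin wrapper over Eigen's tensor contraction, where the single IndexPair decides which axis of each operand is contracted; transpose_a/transpose_b in the op only change those indices. A minimal sketch follows (illustrative, not TensorFlow code), assuming the usual unsupported/Eigen/CXX11/Tensor include path.

#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // dim_pair[0] = {1, 0} contracts the columns of `a` against the rows of
  // `b`, i.e. a plain matrix product; {0, 0} would instead compute a^T * b.
  Eigen::Tensor<float, 2, Eigen::RowMajor> a(2, 3), b(3, 2);
  a.setValues({{1, 2, 3}, {4, 5, 6}});
  b.setValues({{7, 8}, {9, 10}, {11, 12}});
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
  dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
  Eigen::Tensor<float, 2, Eigen::RowMajor> c = a.contract(b, dim_pair);
  for (int i = 0; i < 2; ++i)
    std::printf("%g %g\n", c(i, 0), c(i, 1));  // 58 64 / 139 154
  return 0;
}
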
diff --git a/tensorflow/core/kernels/matmul_op_gpu.cu.cc b/tensorflow/core/kernels/matmul_op_gpu.cu.cc
new file mode 100644
index 0000000000..17107ce5df
--- /dev/null
+++ b/tensorflow/core/kernels/matmul_op_gpu.cu.cc
@@ -0,0 +1,32 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/matmul_op.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization MatMulTensorFunctor<Device=GPUDevice, T>
+template <typename T>
+struct MatMulFunctor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename MatMulTypes<T>::out_type out,
+ typename MatMulTypes<T>::in_type in0,
+ typename MatMulTypes<T>::in_type in1,
+ const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
+ MatMul<GPUDevice>(d, To32Bit(out), To32Bit(in0), To32Bit(in1), dim_pair);
+ }
+};
+
+#define DEFINE(T) template struct MatMulFunctor<GPUDevice, T>;
+DEFINE(float);
+// DEFINE(double); // Does not compile 1/2015.
+#undef DEFINE
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc
new file mode 100644
index 0000000000..b2b8f3d905
--- /dev/null
+++ b/tensorflow/core/kernels/matmul_op_test.cc
@@ -0,0 +1,56 @@
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor in0(DT_FLOAT, transpose_a ? TensorShape({k, m}) : TensorShape({m, k}));
+ in0.flat<float>().setRandom();
+ Tensor in1(DT_FLOAT, transpose_b ? TensorShape({n, k}) : TensorShape({k, n}));
+ in1.flat<float>().setRandom();
+ test::graph::Matmul(g, test::graph::Constant(g, in0),
+ test::graph::Constant(g, in1), transpose_a, transpose_b);
+ return g;
+}
+
+#define BM_MatmulDev(M, K, N, TA, TB, DEVICE) \
+ static void BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE( \
+ int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
+ test::Benchmark(#DEVICE, Matmul(M, K, N, TA, TB)).Run(iters); \
+ } \
+ BENCHMARK(BM_Matmul##_##M##_##K##_##N##_##TA##_##TB##_##DEVICE);
+
+#define BM_Matmul(M, K, N, TA, TB) \
+ BM_MatmulDev(M, K, N, TA, TB, cpu); \
+ BM_MatmulDev(M, K, N, TA, TB, gpu);
+
+// Typical fully connected layers
+BM_Matmul(8, 512, 512, false, false);
+BM_Matmul(16, 512, 512, false, false);
+BM_Matmul(128, 512, 512, false, false);
+
+BM_Matmul(8, 1024, 1024, false, false);
+BM_Matmul(16, 1024, 1024, false, false);
+BM_Matmul(128, 1024, 1024, false, false);
+BM_Matmul(4096, 4096, 4096, false, false);
+
+// Backward for fully connected layers
+BM_Matmul(8, 1024, 1024, false, true);
+BM_Matmul(16, 1024, 1024, false, true);
+BM_Matmul(128, 1024, 1024, false, true);
+
+// Forward softmax with large output size
+BM_Matmul(8, 200, 10000, false, false);
+BM_Matmul(20, 200, 10000, false, false);
+BM_Matmul(20, 200, 20000, false, false);
+
+// Backward softmax with large output size
+BM_Matmul(8, 10000, 200, false, true);
+BM_Matmul(20, 10000, 200, false, true);
+BM_Matmul(20, 20000, 200, false, true);
+
+} // end namespace tensorflow
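The ItemsProcessed line above reports two floating-point operations per inner-product term (one multiply and one add), which is the usual convention for dense matmul flop counting. A tiny sketch of the arithmetic, illustrative only:

#include <cstdio>

int main() {
  // An M x K by K x N matmul does M*N*K multiplies and M*N*K adds.
  const long long M = 128, K = 512, N = 512, iters = 100;
  const long long flops_per_iter = 2LL * M * K * N;
  std::printf("%lld flops per iteration, %lld total\n", flops_per_iter,
              flops_per_iter * iters);
  return 0;
}
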
diff --git a/tensorflow/core/kernels/matrix_inverse_op.cc b/tensorflow/core/kernels/matrix_inverse_op.cc
new file mode 100644
index 0000000000..ad0948d6ef
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_inverse_op.cc
@@ -0,0 +1,64 @@
+// See docs in ../ops/linalg_ops.cc.
+#include <cmath>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/Eigen/LU"
+
+namespace tensorflow {
+
+template <class Scalar, bool SupportsBatchOperationT>
+class MatrixInverseOp
+ : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
+ public:
+ explicit MatrixInverseOp(OpKernelConstruction* context)
+ : LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
+ ~MatrixInverseOp() override {}
+
+ TensorShape GetOutputMatrixShape(
+ const TensorShape& input_matrix_shape) override {
+ return input_matrix_shape;
+ }
+
+ int64 GetCostPerUnit(const TensorShape& input_matrix_shape) override {
+ const int64 rows = input_matrix_shape.dim_size(0);
+ if (rows > (1LL << 20)) {
+ // A big number to cap the cost in case of overflow.
+ return kint32max;
+ } else {
+ return rows * rows * rows;
+ }
+ }
+
+ using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
+ using
+ typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
+ MatrixMap* output) override {
+ OP_REQUIRES(context, input.rows() == input.cols(),
+ errors::InvalidArgument("Input matrix must be square."));
+ if (input.rows() == 0) {
+ // By definition, an empty matrix's inverse is an empty matrix.
+ return;
+ }
+ Eigen::FullPivLU<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>> lu_decomposition(input);
+ OP_REQUIRES(context, lu_decomposition.isInvertible(),
+ errors::InvalidArgument("Input is not invertible."));
+ *output = lu_decomposition.inverse();
+ }
+};
+
+REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<float, false>), float);
+REGISTER_LINALG_OP("MatrixInverse", (MatrixInverseOp<double, false>), double);
+REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<float, true>), float);
+REGISTER_LINALG_OP("BatchMatrixInverse", (MatrixInverseOp<double, true>),
+ double);
+
+} // namespace tensorflow
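ComputeMatrix above reduces to Eigen's full-pivoting LU: factor the input, reject singular matrices, then invert. Below is a standalone sketch of the same calls, illustrative only; the op maps row-major tensor data rather than using MatrixXd.

#include <cstdio>
#include <Eigen/LU>

int main() {
  Eigen::MatrixXd m(2, 2);
  m << 4, 7,
       2, 6;
  // Factor once, then use the factorization for both the invertibility check
  // and the inverse, mirroring the kernel above.
  Eigen::FullPivLU<Eigen::MatrixXd> lu(m);
  if (!lu.isInvertible()) {
    std::printf("matrix is singular\n");
    return 1;
  }
  Eigen::MatrixXd inv = lu.inverse();  // [[0.6, -0.7], [-0.2, 0.4]]
  std::printf("%g %g\n%g %g\n", inv(0, 0), inv(0, 1), inv(1, 0), inv(1, 1));
  return 0;
}
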
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
new file mode 100644
index 0000000000..31046018c5
--- /dev/null
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -0,0 +1,554 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/maxpooling_op.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/pooling_ops_common.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
+#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+const int kInvalidMaxPoolingIndex = -1;
+
+template <typename Device, typename T>
+struct SpatialMaxPoolWithArgMaxHelper {
+ static void Compute(Tensor* output, Tensor* output_arg_max,
+ const Tensor& tensor_in, const PoolParameters& params,
+ const Padding& padding) {
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ ConstEigenMatrixMap;
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenMatrixMap;
+ typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenIndexMatrixMap;
+
+ ConstEigenMatrixMap in_mat(
+ tensor_in.flat<T>().data(), params.depth,
+ params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
+ EigenMatrixMap out_mat(
+ output->flat<T>().data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+ EigenIndexMatrixMap out_arg_max_mat(
+ output_arg_max->flat<int64>().data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+
+ // Initializes the output tensor with MIN<T> and the argmax tensor with kInvalidMaxPoolingIndex.
+ output_arg_max->flat<int64>().setConstant(kInvalidMaxPoolingIndex);
+ output->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
+
+ // The following code basically does the following:
+ // 1. Flattens the input and output tensors into two dimensional arrays.
+ // tensor_in_as_matrix:
+ // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+ // output_as_matrix:
+ // depth by (out_width * out_height * tensor_in_batch)
+ //
+ // 2. Walks through the set of columns in the flattened tensor_in_as_matrix,
+ // and updates the corresponding column(s) in output_as_matrix with the
+ // max value.
+ for (int b = 0; b < params.tensor_in_batch; ++b) {
+ for (int h = 0; h < params.tensor_in_rows; ++h) {
+ for (int w = 0; w < params.tensor_in_cols; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ const int hpad = h + params.pad_rows;
+ const int wpad = w + params.pad_cols;
+ const int h_start =
+ (hpad < params.window_rows)
+ ? 0
+ : (hpad - params.window_rows) / params.row_stride + 1;
+ const int h_end =
+ std::min(hpad / params.row_stride + 1, params.out_height);
+ const int w_start =
+ (wpad < params.window_cols)
+ ? 0
+ : (wpad - params.window_cols) / params.col_stride + 1;
+ const int w_end =
+ std::min(wpad / params.col_stride + 1, params.out_width);
+ // compute elementwise max
+ const int in_index =
+ (b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ const int out_index =
+ (b * params.out_height + ph) * params.out_width + pw;
+ /// NOTES(zhengxq): not using the eigen matrix operation for now.
+ /// May consider parallelizing the operations if needed.
+ for (int d = 0; d < params.depth; ++d) {
+ const T& input_ref = in_mat.coeffRef(d, in_index);
+ T& output_ref = out_mat.coeffRef(d, out_index);
+ int64& out_arg_max_ref = out_arg_max_mat.coeffRef(d, out_index);
+ if (output_ref < input_ref ||
+ out_arg_max_ref == kInvalidMaxPoolingIndex) {
+ output_ref = input_ref;
+ int input_offset = in_index * params.depth + d;
+ out_arg_max_ref = input_offset;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_CPU),
+ MaxPoolingOp<CPUDevice, float>);
+
+#if GOOGLE_CUDA
+// Forward declarations for the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void SpatialMaxPooling<Eigen::GpuDevice, T>::operator()( \
+ const Eigen::GpuDevice& d, typename TTypes<T, 4>::Tensor output, \
+ typename TTypes<T, 4>::ConstTensor input, int window_rows, \
+ int window_cols, int row_stride, int col_stride, \
+ const Eigen::PaddingType& padding); \
+ extern template struct SpatialMaxPooling<Eigen::GpuDevice, T>;
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+// Note(jiayq): Currently, the Caffe custom implementation is faster than the
+// default Eigen implementation, so we are using the custom kernel as the
+// default. However, you can explicitly invoke the Eigen version using
+// kernel_label_map.
+REGISTER_KERNEL_BUILDER(Name("MaxPool")
+ .Device(DEVICE_GPU)
+ .Label("eigen_tensor"),
+ MaxPoolingOp<Eigen::GpuDevice, float>);
+#endif // GOOGLE_CUDA
+
+// The operation to compute MaxPool gradients.
+// It takes three inputs:
+// - The original input tensor
+// - The original output tensor
+// - Backprop tensor for output
+// It produces one output: backprop tensor for input.
+template <class Device, class T>
+class MaxPoolingGradOp : public OpKernel {
+ public:
+ explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ OP_REQUIRES(
+ context, ksize_[3] == 1 && stride_[3] == 1,
+ errors::Unimplemented(
+ "MaxPoolingGrad is not yet supported on the depth dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& tensor_out = context->input(1);
+ const Tensor& out_backprop = context->input(2);
+
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ OP_REQUIRES(context, tensor_out.dims() == 4,
+ errors::InvalidArgument("tensor_out must be 4-dimensional"));
+ // For maxpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+
+ TensorShape output_shape = tensor_in.shape();
+
+ // Tensor index_tensor(context->allocator(), DT_INT32, output_shape);
+
+ Tensor tensor_out_dup;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::v(),
+ tensor_out.shape(), &tensor_out_dup));
+ Tensor tensor_out_arg_max;
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
+ tensor_out.shape(),
+ &tensor_out_arg_max));
+
+ PoolParameters params{context, ksize_, stride_, padding_,
+ tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ output->flat<T>().setZero();
+
+ SpatialMaxPoolWithArgMaxHelper<CPUDevice, T>::Compute(
+ &tensor_out_dup, &tensor_out_arg_max, tensor_in, params, padding_);
+ auto out_backprop_flat = out_backprop.flat<T>();
+ auto input_backprop_flat = output->flat<T>();
+ auto out_arg_max_flat = tensor_out_arg_max.flat<int64>();
+ int num_total_outputs = out_backprop.flat<T>().size();
+ int num_total_inputs = input_backprop_flat.size();
+
+ for (int index = 0; index < num_total_outputs; ++index) {
+ int input_backprop_index = out_arg_max_flat(index);
+ // Although this check is in the inner loop, it is worth its cost
+ // because it prevents memory corruption. Our benchmark shows that
+ // the performance impact is quite small.
+ CHECK(input_backprop_index >= 0 &&
+ input_backprop_index < num_total_inputs)
+ << "Invalid input backprop index: " << input_backprop_index << ", "
+ << num_total_inputs;
+ input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
+ }
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_CPU),
+ MaxPoolingGradOp<CPUDevice, float>);
+
+#ifdef GOOGLE_CUDA
+
+static void MaxPoolingBackwardCustomKernel(
+ OpKernelContext* context, const std::vector<int32>& size,
+ const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in,
+ const Tensor& out_backprop, const TensorShape& tensor_in_shape) {
+ Tensor* output = nullptr;
+
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, tensor_in_shape, &output));
+
+ PoolParameters params{context, size, stride, padding, tensor_in_shape};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ MaxPoolBackwardNoMask(
+ tensor_in->flat<float>().data(), params.tensor_in_batch,
+ params.tensor_in_rows, params.tensor_in_cols, params.depth,
+ params.out_height, params.out_width, params.window_rows,
+ params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
+ params.pad_cols, out_backprop.flat<float>().data(),
+ output->flat<float>().data(), context->eigen_device<Eigen::GpuDevice>());
+}
+
+template <class T>
+class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
+ public:
+ typedef Eigen::GpuDevice Device;
+
+ explicit MaxPoolingGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+
+ use_dnn_ = CanUseCudnn();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& tensor_out = context->input(1);
+ const Tensor& out_backprop = context->input(2);
+
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional 4"));
+ OP_REQUIRES(context, tensor_out.dims() == 4,
+ errors::InvalidArgument("tensor_out must be 4-dimensional"));
+ // For maxpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(context, out_backprop.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+
+ TensorShape output_shape = tensor_in.shape();
+
+ if (use_dnn_) {
+ DnnPoolingGradOp<T>::Compute(
+ context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
+ stride_, padding_, &tensor_in, &tensor_out, out_backprop,
+ output_shape);
+ } else {
+ MaxPoolingBackwardCustomKernel(context, ksize_, stride_, padding_,
+ &tensor_in, out_backprop, output_shape);
+ }
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ bool use_dnn_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_GPU),
+ MaxPoolingGradOp<Eigen::GpuDevice, float>);
+
+#endif // GOOGLE_CUDA
+
+template <typename Device, typename T>
+struct LaunchMaxPoolingNoMask;
+
+template <typename Device, typename T>
+class MaxPoolingNoMaskOp : public OpKernel {
+ public:
+ explicit MaxPoolingNoMaskOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+
+ PoolParameters params{context, ksize_, stride_, padding_,
+ tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ TensorShape out_shape({params.tensor_in_batch, params.out_height,
+ params.out_width, params.depth});
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
+ output);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+template <typename Device, typename T>
+struct LaunchMaxPoolingWithArgmax;
+
+template <typename Device, typename T>
+class MaxPoolingWithArgmaxOp : public OpKernel {
+ public:
+ explicit MaxPoolingWithArgmaxOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+
+ PoolParameters params{context, ksize_, stride_, padding_,
+ tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ TensorShape out_shape({params.tensor_in_batch, params.out_height,
+ params.out_width, params.depth});
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+ Tensor* argmax = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, out_shape, &argmax));
+
+ LaunchMaxPoolingWithArgmax<Device, T>::launch(context, params, tensor_in,
+ output, argmax);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+template <typename Device, typename T>
+struct LaunchMaxPoolingGradWithArgmax;
+
+template <typename Device, typename T>
+class MaxPoolingGradWithArgmaxOp : public OpKernel {
+ public:
+ explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument(
+ "Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ const Tensor& grad_in = context->input(1);
+ const Tensor& argmax = context->input(2);
+
+ PoolParameters params{context, ksize_, stride_, padding_,
+ tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ TensorShape out_shape({params.tensor_in_batch, params.tensor_in_rows,
+ params.tensor_in_cols, params.depth});
+ Tensor* grad_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &grad_out));
+
+ LaunchMaxPoolingGradWithArgmax<Device, T>::launch(context, params, grad_in,
+ argmax, grad_out);
+ }
+
+ private:
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+#if GOOGLE_CUDA
+
+template <typename T>
+struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
+ static void launch(OpKernelContext* context, const PoolParameters& params,
+ const Tensor& input, Tensor* output) {
+ bool status = MaxPoolForwardWithOptionalArgmax(
+ input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
+ params.tensor_in_cols, params.depth, params.out_height,
+ params.out_width, params.window_rows, params.window_cols,
+ params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
+ output->flat<T>().data(), nullptr, context->eigen_gpu_device());
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launching MaxPoolForwardNoMask"));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_GPU),
+ MaxPoolingNoMaskOp<Eigen::GpuDevice, float>);
+
+template <typename T>
+struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
+ static void launch(OpKernelContext* context, const PoolParameters& params,
+ const Tensor& input, Tensor* output, Tensor* argmax) {
+ bool status = MaxPoolForwardWithOptionalArgmax(
+ input.flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows,
+ params.tensor_in_cols, params.depth, params.out_height,
+ params.out_width, params.window_rows, params.window_cols,
+ params.row_stride, params.col_stride, params.pad_rows, params.pad_cols,
+ output->flat<T>().data(),
+ reinterpret_cast<int64*>(argmax->flat<int64>().data()),
+ context->eigen_gpu_device());
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launching MaxPoolForwardWithArgmax"));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("Targmax"),
+ MaxPoolingWithArgmaxOp<Eigen::GpuDevice, float>);
+
+template <typename T>
+struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
+ static void launch(OpKernelContext* context, const PoolParameters& params,
+ const Tensor& grad_in, const Tensor& argmax,
+ Tensor* grad_out) {
+ const int input_size = params.tensor_in_batch * params.tensor_in_rows *
+ params.tensor_in_cols * params.depth;
+ const int output_size = params.tensor_in_batch * params.out_height *
+ params.out_width * params.depth;
+ const int top_offset = params.out_height * params.out_width * params.depth;
+ const int bottom_offset =
+ params.tensor_in_rows * params.tensor_in_cols * params.depth;
+ bool status = MaxPoolBackwardWithArgmax(
+ output_size, input_size, grad_in.flat<T>().data(),
+ reinterpret_cast<const int64*>(argmax.flat<int64>().data()), top_offset,
+ bottom_offset, grad_out->flat<T>().data(), context->eigen_gpu_device());
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launching MaxPoolForwardWithArgmax"));
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("Targmax"),
+ MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
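The forward helper above records, for every output element, the flattened index of the input element that produced the maximum, and MaxPoolingGradOp then scatters each output gradient back to exactly that index. Below is a hedged 1-D, VALID-padding reduction of the same idea, illustrative only; the kernels above work on 4-D NHWC tensors with padding.

#include <cstdio>
#include <vector>

// 1-D max pooling that records the flattened input index of each maximum.
void MaxPool1DWithArgmax(const std::vector<float>& in, int window, int stride,
                         std::vector<float>* out, std::vector<int>* argmax) {
  for (int start = 0; start + window <= static_cast<int>(in.size());
       start += stride) {
    int best = start;
    for (int i = start + 1; i < start + window; ++i)
      if (in[i] > in[best]) best = i;
    out->push_back(in[best]);
    argmax->push_back(best);  // flattened input index, as in the kernel
  }
}

int main() {
  const std::vector<float> in = {1, 3, 2, 5, 4, 0};
  std::vector<float> out;
  std::vector<int> argmax;
  MaxPool1DWithArgmax(in, /*window=*/2, /*stride=*/2, &out, &argmax);
  // Backward: route each output gradient to the input element that won.
  const std::vector<float> grad_out = {10, 20, 30};  // d(loss)/d(out)
  std::vector<float> grad_in(in.size(), 0.0f);
  for (size_t i = 0; i < argmax.size(); ++i) grad_in[argmax[i]] += grad_out[i];
  for (float g : grad_in) std::printf("%g ", g);  // 0 10 0 20 30 0
  std::printf("\n");
  return 0;
}
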
diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h
new file mode 100644
index 0000000000..a074174118
--- /dev/null
+++ b/tensorflow/core/kernels/maxpooling_op.h
@@ -0,0 +1,29 @@
+#ifndef TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
+#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
+// Functor definition for MaxPoolingOp, must be compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct SpatialMaxPooling {
+ void operator()(const Device& d, typename TTypes<T, 4>::Tensor output,
+ typename TTypes<T, 4>::ConstTensor input, int window_rows,
+ int window_cols, int row_stride, int col_stride,
+ const Eigen::PaddingType& padding) {
+ // Because we swap the layout, we swap the row/cols as well
+ output.swap_layout().device(d) =
+ Eigen::SpatialMaxPooling(input.swap_layout(), window_cols, window_rows,
+ col_stride, row_stride, padding);
+ }
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
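SpatialMaxPooling above hands a row-major TensorFlow tensor to an Eigen expression that expects the other layout, and swap_layout() reinterprets the same buffer with the layout and the dimension order reversed, which is why the window and stride arguments are passed swapped. A minimal sketch of what swap_layout() does, illustrative and not TensorFlow code:

#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2, Eigen::RowMajor> t(2, 3);
  t.setValues({{1, 2, 3}, {4, 5, 6}});
  // swap_layout() keeps the underlying buffer but views it as column-major
  // with the dimensions reversed (3 x 2); indexing is therefore transposed.
  Eigen::Tensor<float, 2, Eigen::ColMajor> s = t.swap_layout();
  std::printf("%ld x %ld, s(0,1) = %g\n",
              static_cast<long>(s.dimension(0)),
              static_cast<long>(s.dimension(1)), s(0, 1));  // 3 x 2, 4
  return 0;
}
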
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
new file mode 100644
index 0000000000..65262eb54e
--- /dev/null
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -0,0 +1,261 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/maxpooling_op.h"
+#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
+
+namespace tensorflow {
+namespace {
+// This is Yangqing's custom kernel for the maxpooling operation. There are
+// three functions: MaxPoolForwardNCHW and MaxPoolForwardNHWC implement the
+// forward pass for the NCHW and NHWC storage orders, and MaxPoolBackward
+// implements the backward pass for both storage orders.
+// The parameters to the forward kernels are as follows:
+// nthreads: the number of threads, which is equal to the output size.
+// bottom_data: the bottom data of N*H*W*C (or N*C*H*W) items.
+// height, width, pooled_height, pooled_width: the input and output sizes.
+// kernel_h, kernel_w: the kernel sizes.
+// stride_h, stride_w: the strides.
+// pad_t, pad_l: the padding values on the top and left side.
+// top_data: the maxpool output.
+// mask: the output mask of the same size as top_data. It is stored in
+// int form, keeping track of the flattened index of the input item that
+// produces the max output. If a nullptr is passed in for mask, no mask
+// will be produced.
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); i += blockDim.x * gridDim.x)
+
+// To call the forward and backward functions, use e.g.:
+// const int kThreadsPerBlock = 1024
+// const int output_size = batch * channels * pooled_height * pooled_width;
+// MaxPoolForwardNCHW<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+// kThreadsPerBlock, 0, cuda_stream>>>(...);
+template <typename dtype>
+__global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
+ const int channels, const int height,
+ const int width, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t,
+ const int pad_l, dtype* top_data,
+ int64* mask) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ int hstart = ph * stride_h - pad_t;
+ int wstart = pw * stride_w - pad_l;
+ int hend = min(hstart + kernel_h, height);
+ int wend = min(wstart + kernel_w, width);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ dtype maxval = -FLT_MAX;
+ int maxidx = -1;
+ const dtype* bottom_data_n = bottom_data + n * channels * height * width;
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ int idx = c * height * width + h * width + w;
+ if (bottom_data_n[idx] > maxval) {
+ maxidx = idx;
+ maxval = bottom_data_n[idx];
+ }
+ }
+ }
+ top_data[index] = maxval;
+ if (mask != nullptr) {
+ mask[index] = maxidx;
+ }
+ }
+}
+
+template <typename dtype>
+__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t,
+ const int pad_l, dtype* top_data,
+ int64* mask) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int n = index;
+ int c = n % channels;
+ n /= channels;
+ int wstart = (n % pooled_width) * stride_w - pad_l;
+ n /= pooled_width;
+ int hstart = (n % pooled_height) * stride_h - pad_t;
+ n /= pooled_height;
+ int hend = min(hstart + kernel_h, height);
+ int wend = min(wstart + kernel_w, width);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ dtype maxval = -FLT_MAX;
+ int maxidx = -1;
+ const dtype* bottom_data_n = bottom_data + n * height * width * channels;
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ int idx = (h * width + w) * channels + c;
+ if (bottom_data_n[idx] > maxval) {
+ maxidx = idx;
+ maxval = bottom_data_n[idx];
+ }
+ }
+ }
+ top_data[index] = maxval;
+ if (mask != nullptr) {
+ mask[index] = maxidx;
+ }
+ }
+}
+
+template <typename dtype>
+__global__ void MaxPoolBackwardNoMaskNHWC(
+ const int nthreads, const dtype* bottom_data, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ const dtype* top_diff, dtype* bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // First find out the index to the maximum, since we have no mask.
+ int n = index;
+ int c = n % channels;
+ n /= channels;
+ int wstart = (n % pooled_width) * stride_w - pad_l;
+ n /= pooled_width;
+ int hstart = (n % pooled_height) * stride_h - pad_t;
+ n /= pooled_height;
+ int hend = min(hstart + kernel_h, height);
+ int wend = min(wstart + kernel_w, width);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ dtype maxval = -FLT_MAX;
+ int maxidx = -1;
+ const dtype* bottom_data_n = bottom_data + n * height * width * channels;
+ for (int h = hstart; h < hend; ++h) {
+ for (int w = wstart; w < wend; ++w) {
+ int idx = (h * width + w) * channels + c;
+ if (bottom_data_n[idx] > maxval) {
+ maxidx = idx;
+ maxval = bottom_data_n[idx];
+ }
+ }
+ }
+
+ // Atomically accumulate the bottom diff. The index could still be
+ // uninitialized, if all the bottom_data are NaN.
+ if (maxidx != -1) {
+ atomicAdd(bottom_diff + n * height * width * channels + maxidx,
+ top_diff[index]);
+ }
+ }
+}
+
+// The parameters to the backward kernel are as follows:
+// nthreads: the number of threads, which is equal to the output size.
+// top_diff: the gradient of the output data, of size N*Hout*Wout*C (or
+// N*C*Hout*Wout). As we have stored the flattened index of the input
+// entries, the backward function is agnostic of the input storage order.
+// mask: the output mask of the same size as top_data. It is stored in
+// int form, keeping track of the flattened index of the input item that
+// produces the max output.
+// top_offset: the pre-computed per-image offset of the maxpool output. This
+// is equal to Hout*Wout*C. We choose to pre-compute this so we do not
+// need to compute it every time inside the kernel.
+// bottom_offset: the pre-computed per-image offset of the maxpool input.
+// This is equal to H*W*C.
+// bottom_diff: the gradient with respect to the input.
+// This function relies on atomicAdd to avoid race conditions. Also, before the
+// kernel is run, you will need to make sure that bottom_diff is filled with
+// zero first.
+template <typename dtype>
+__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
+ const int64* mask, const int top_offset,
+ const int bottom_offset, dtype* bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ int image_id = (index / top_offset);
+ atomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
+ top_diff[index]);
+ }
+}
+
+template <typename dtype>
+__global__ void SetZero(const int nthreads, dtype* bottom_diff) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = dtype(0); }
+}
+
+#undef CUDA_1D_KERNEL_LOOP
+} // namespace
+
+bool MaxPoolForwardWithOptionalArgmax(
+ const float* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ float* top_data, int64* mask, const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ const int output_size = batch * channels * pooled_height * pooled_width;
+
+ MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ output_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_data, mask);
+ return d.ok();
+}
+
+bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l,
+ const float* top_diff, float* bottom_diff,
+ const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ const int bottom_size = batch * channels * height * width;
+ const int top_size = batch * channels * pooled_height * pooled_width;
+
+ SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
+
+ MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
+ kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ top_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_diff, bottom_diff);
+ return d.ok();
+}
+
+bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
+ const float* top_diff, const int64* mask,
+ const int top_offset, const int bottom_offset,
+ float* bottom_diff, const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
+ MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
+ return d.ok();
+}
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::SpatialMaxPooling<GPUDevice, T>;
+
+DEFINE_GPU_KERNELS(float)
+
+#undef DEFINE_GPU_KERNELS
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h
new file mode 100644
index 0000000000..bfdac904cc
--- /dev/null
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.h
@@ -0,0 +1,42 @@
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+
+namespace tensorflow {
+
+// Runs the forward pass of max pooling, optionally writing the argmax indices
+// to the mask array. If mask is passed in as nullptr, the argmax indices are
+// not written.
+bool MaxPoolForwardWithOptionalArgmax(
+ const float* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ float* top_data, int64* mask, const Eigen::GpuDevice& d);
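+//
+// A usage sketch (illustrative only; the variable names below are assumptions,
+// not part of this header). Callers that do not need the argmax pass
+// mask == nullptr; otherwise mask must point to an int64 buffer with one entry
+// per output element (batch * pooled_height * pooled_width * channels):
+//
+//   bool launched = MaxPoolForwardWithOptionalArgmax(
+//       input, batch, height, width, channels, pooled_height, pooled_width,
+//       kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, output,
+//       /*mask=*/nullptr, eigen_gpu_device);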
+
+bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
+ const float* top_diff, const int64* mask,
+ const int top_offset, const int bottom_offset,
+ float* bottom_diff, const Eigen::GpuDevice& d);
+
+bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l,
+ const float* top_diff, float* bottom_diff,
+ const Eigen::GpuDevice& d);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_
diff --git a/tensorflow/core/kernels/no_op.cc b/tensorflow/core/kernels/no_op.cc
new file mode 100644
index 0000000000..b4f9df81a6
--- /dev/null
+++ b/tensorflow/core/kernels/no_op.cc
@@ -0,0 +1,8 @@
+#include "tensorflow/core/kernels/no_op.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_CPU), NoOp);
+REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE_GPU), NoOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/no_op.h b/tensorflow/core/kernels/no_op.h
new file mode 100644
index 0000000000..a3bcbd7680
--- /dev/null
+++ b/tensorflow/core/kernels/no_op.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_KERNELS_NO_OP_H_
+#define TENSORFLOW_KERNELS_NO_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class NoOp : public OpKernel {
+ public:
+ explicit NoOp(OpKernelConstruction* context) : OpKernel(context) {}
+ void Compute(OpKernelContext* context) override {}
+ bool IsExpensive() override { return false; }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_NO_OP_H_
diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc
new file mode 100644
index 0000000000..7bea17b9e2
--- /dev/null
+++ b/tensorflow/core/kernels/ops_testutil.cc
@@ -0,0 +1,18 @@
+#include "tensorflow/core/kernels/ops_testutil.h"
+
+namespace tensorflow {
+namespace test {
+
+NodeDef Node(const string& name, const string& op,
+ const std::vector<string>& inputs) {
+ NodeDef def;
+ def.set_name(name);
+ def.set_op(op);
+ for (const string& s : inputs) {
+ def.add_input(s);
+ }
+ return def;
+}
+
+} // namespace test
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
new file mode 100644
index 0000000000..7a3405bf04
--- /dev/null
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -0,0 +1,191 @@
+#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
+#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace test {
+
+// Return a NodeDef with the specified name/op/inputs.
+NodeDef Node(const string& name, const string& op,
+ const std::vector<string>& inputs);
+
+} // namespace test
+
+// Helper base class for writing tests of op kernels.
+//
+// This class will eventually be replaced / heavily modified
+// to use the BrainClient interface.
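+//
+// A minimal usage sketch (illustrative only, not an existing test; it relies
+// on the NoOp kernel registered in no_op.cc):
+//
+//   class NoOpTest : public OpsTestBase {};
+//
+//   TEST_F(NoOpTest, Runs) {
+//     set_node_def(test::Node("n", "NoOp", {}));
+//     TF_ASSERT_OK(InitOp());
+//     TF_ASSERT_OK(RunOpKernel());
+//   }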
+class OpsTestBase : public ::testing::Test {
+ public:
+ OpsTestBase() : device_type_(DEVICE_CPU) {
+ device_.reset(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+ CHECK(device_.get()) << "Could not create CPU device";
+ }
+
+ ~OpsTestBase() override {
+ gtl::STLDeleteElements(&tensors_);
+ context_.reset(nullptr);
+ }
+
+ void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
+
+ // Clients can manipulate the underlying NodeDef via this accessor.
+ NodeDef* node_def() { return &node_def_; }
+
+  // Initializes the op kernel described by node_def_ on the test device,
+  // recording the kernel's expected input types.
+ //
+ // Returns the status of initialization.
+ Status InitOp() {
+ Status status;
+ kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
+ node_def_, &status);
+ if (kernel_ != nullptr) input_types_ = kernel_->input_types();
+ return status;
+ }
+
+  // Adds an input tensor with the given shape, where element i is set to
+  // 'input_mapping(i)' for i in [0, NumElements(shape)).
+ //
+ // TODO(vrv): Replace with something like a BrainClient Feed.
+ template <typename T>
+ void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
+ CHECK_GT(input_types_.size(), inputs_.size())
+ << "Adding more inputs than types; perhaps you need to call MakeOp";
+ bool is_ref = IsRefType(input_types_[inputs_.size()]);
+ Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
+ DataTypeToEnum<T>::v(), shape);
+ test::FillFn(input, input_mapping);
+ tensors_.push_back(input);
+ if (is_ref) {
+ CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
+ DataTypeToEnum<T>::v());
+ inputs_.push_back({&lock_for_refs_, input});
+ } else {
+ CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
+ inputs_.push_back({nullptr, input});
+ }
+ }
+
+  // Like AddInput, but takes an explicit ArraySlice of data.
+ template <typename T>
+ void AddInputFromArray(const TensorShape& shape,
+ const gtl::ArraySlice<T>& data) {
+ CHECK_GT(input_types_.size(), inputs_.size())
+ << "Adding more inputs than types; perhaps you need to call MakeOp";
+ bool is_ref = IsRefType(input_types_[inputs_.size()]);
+ Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
+ DataTypeToEnum<T>::v(), shape);
+ test::FillValues<T>(input, data);
+ tensors_.push_back(input);
+ if (is_ref) {
+ CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
+ DataTypeToEnum<T>::v());
+ inputs_.push_back({&lock_for_refs_, input});
+ } else {
+ CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<T>::v());
+ inputs_.push_back({nullptr, input});
+ }
+ }
+
+  // Runs the op kernel and stores its outputs in the context.
+ //
+ // Returns the context's status after running the operation.
+ Status RunOpKernel() {
+ OpKernelContext::Params params;
+ params.device = device_.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs_;
+ params.op_kernel = kernel_.get();
+ params.output_alloc_attr = [this, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host =
+ (kernel_->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+ checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
+ params.slice_reader_cache = &slice_reader_cache_wrapper;
+
+ context_.reset(new OpKernelContext(params));
+ device_->Compute(kernel_.get(), context_.get());
+ return context_->status();
+ }
+
+ // Returns the tensor input for 'input_index'.
+ //
+ // REQUIRES: 0 <= input_index < context_->num_inputs()
+ const Tensor& GetInput(int input_index) const {
+ CHECK_LT(input_index, context_->num_inputs());
+ CHECK(!IsRefType(context_->input_dtype(input_index)));
+ return context_->input(input_index);
+ }
+
+ TensorValue mutable_input(int input_index) {
+ CHECK_LT(input_index, inputs_.size());
+ return inputs_[input_index];
+ }
+ // Returns the tensor output for 'output_index'.
+ //
+ // REQUIRES: 0 <= output_index < context_->num_outputs()
+ Tensor* GetOutput(int output_index) {
+ CHECK_LT(output_index, context_->num_outputs());
+ return context_->mutable_output(output_index);
+ }
+
+ Allocator* allocator() {
+ return device_->GetAllocator(AllocatorAttributes());
+ }
+
+ const DataTypeVector& output_types() const { return kernel_->output_types(); }
+
+ protected:
+ std::unique_ptr<Device> device_;
+
+ std::unique_ptr<OpKernel> kernel_;
+ NodeDef node_def_;
+ DataTypeVector input_types_;
+ DeviceType device_type_;
+
+ mutex lock_for_refs_; // Used as the Mutex for inputs added as refs
+
+ gtl::InlinedVector<TensorValue, 4> inputs_;
+ // Owns Tensors.
+ std::vector<Tensor*> tensors_;
+
+ std::unique_ptr<OpKernelContext> context_;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc
new file mode 100644
index 0000000000..ca2925128e
--- /dev/null
+++ b/tensorflow/core/kernels/ops_util.cc
@@ -0,0 +1,113 @@
+#include <cmath>
+
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+
+namespace tensorflow {
+
+void RequireDefaultOps() {
+// TODO(opensource): Use a more generic sounding preprocessor name than
+// GOOGLE_CUDA (maybe SUPPORT_CUDA?)
+#if GOOGLE_CUDA
+ void RequireGPUDevice();
+ RequireGPUDevice();
+#endif
+}
+
+Status Get2dOutputSize(const int in_height, const int in_width,
+ int filter_height, int filter_width, int row_stride,
+ int col_stride, Padding padding, int* new_height,
+ int* new_width, int* pad_rows, int* pad_cols) {
+ int pad_bottom_unused, pad_right_unused;
+ return Get2dOutputSizeVerbose(
+ in_height, in_width, filter_height, filter_width, row_stride, col_stride,
+ padding, new_height, new_width, pad_rows, &pad_bottom_unused, pad_cols,
+ &pad_right_unused);
+}
+
+Status Get2dOutputSizeVerbose(const int in_height, const int in_width,
+ int filter_height, int filter_width,
+ int row_stride, int col_stride, Padding padding,
+ int* new_height, int* new_width, int* pad_top,
+ int* pad_bottom, int* pad_left, int* pad_right) {
+ // Cannot have strides larger than the patch size.
+ if (row_stride > filter_height || col_stride > filter_width) {
+ return errors::InvalidArgument(
+ "stride must be less than or equal to kernel size");
+ }
+ switch (padding) {
+ case Padding::VALID:
+ *new_height = ceil((in_height - filter_height + 1.f) /
+ static_cast<float>(row_stride));
+ *new_width = ceil((in_width - filter_width + 1.f) /
+ static_cast<float>(col_stride));
+ *pad_top = 0;
+ *pad_bottom = 0;
+ *pad_left = 0;
+ *pad_right = 0;
+ break;
+ case Padding::SAME:
+ *new_height = ceil(in_height / static_cast<float>(row_stride));
+ *new_width = ceil(in_width / static_cast<float>(col_stride));
+ // Calculate padding for top/bottom/left/right, spilling any excess
+ // padding to bottom and right.
+ const int pad_needed_height =
+ (*new_height - 1) * row_stride + filter_height - in_height;
+ *pad_top = pad_needed_height / 2;
+ CHECK_GE(pad_needed_height, 0);
+ *pad_bottom = pad_needed_height - *pad_top;
+
+ const int pad_needed_width =
+ (*new_width - 1) * col_stride + filter_width - in_width;
+ *pad_left = pad_needed_width / 2;
+ CHECK_GE(pad_needed_width, 0);
+ *pad_right = pad_needed_width - *pad_left;
+ break;
+ }
+ if (*new_height < 0 || *new_width < 0) {
+ return errors::InvalidArgument("computed output size would be negative");
+ }
+ return Status::OK();
+}
+
+Eigen::PaddingType BrainPadding2EigenPadding(Padding padding) {
+ switch (padding) {
+ case Padding::VALID:
+ return Eigen::PADDING_VALID;
+ case Padding::SAME:
+ return Eigen::PADDING_SAME;
+ }
+ return Eigen::PADDING_SAME; // Prevent compiler warning about missing return
+}
+
+Status GetBroadcastSize(const int index, const int in_size,
+ const int ksize, const int stride,
+ const int pad_size, int* bindex, int* bsize) {
+ // Cannot have strides larger than the patch size.
+ if (stride > ksize) {
+ return errors::InvalidArgument(
+ "stride must be less than or equal to kernel size");
+ }
+ // Cannot have index beyond the input size.
+ if (index * stride > in_size) {
+ return errors::InvalidArgument(
+ "index * stride must be less than or equal to input size");
+ }
+ *bindex = index * stride;
+ *bsize = ksize;
+ if (*bindex < pad_size) {
+ // If the current index is in the padding area, start broadcast from index
+ // 0 with broadcast size reduced by padding size.
+ *bsize = ksize + *bindex - pad_size;
+ *bindex = 0;
+ } else {
+ // Otherwise, start broadcast from current index reduced by padding size.
+ *bindex -= pad_size;
+ }
+ if (*bindex + ksize > in_size) {
+ *bsize = std::min((in_size - *bindex), ksize);
+ }
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h
new file mode 100644
index 0000000000..283338f8df
--- /dev/null
+++ b/tensorflow/core/kernels/ops_util.h
@@ -0,0 +1,180 @@
+#ifndef TENSORFLOW_KERNELS_OPS_UTIL_H_
+#define TENSORFLOW_KERNELS_OPS_UTIL_H_
+
+// This file contains utilities for various operations.
+
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// Call this function from a test if op kernels are not being
+// registered. This can happen if the test is linked in a shared
+// mode and has no direct references to any code from this directory.
+void RequireDefaultOps();
+
+// Get2dOutputSize(): Given an input tensor, kernel, stride and padding
+// type, the function computes the output and padding dimensions.
+//
+// Convolution layers take in an input tensor of shape (D, C, R, B), and
+// convolve it with a set of filters, which can also be presented as a
+// tensor (D, K, K, M), where M is the number of filters, K is the filter size,
+// and each 3-dimensional tensor of size (D, K, K) is a filter. For
+// simplicity we assume that we always use square filters (which is usually the
+// case in images). It also takes in a few additional parameters:
+//
+// Stride (S): the stride with which we apply the filters. This is the offset
+// between locations where we apply the filters. A larger stride
+// means that the output will be spatially smaller.
+//
+// Padding (P): the padding we apply to the input tensor along the R and C
+//   dimensions. This is usually used to make sure that the spatial dimensions
+// do not shrink when we progress with convolutions. Two types of padding are
+// often used:
+// SAME: the pad value is computed so that the output will have size R/S
+// and C/S.
+// VALID: no padding is carried out.
+// The padded area is zero-filled.
+//
+// The output dimensions for convolution and many other operations, when given
+// all the parameters above, are as follows:
+// - When Padding = SAME: the output size is (B, R', C', M), where
+// R' = ceil(float(R) / float(S))
+// C' = ceil(float(C) / float(S))
+// where ceil is the ceiling function. The number of padded rows and columns
+// are computed as:
+// Pr = ((R' - 1) * S + K - R) / 2
+// Pc = ((C' - 1) * S + K - C) / 2
+// When the stride is 1, we have the simplified case
+// R'=R, C'=C, Pr=Pc=(K-1)/2.
+//   This is where SAME comes from - the output has the same size as the
+//   input.
+//
+// - When Padding = VALID: the output size is computed as
+// R' = ceil(float(R - K + 1) / float(S))
+// C' = ceil(float(C - K + 1) / float(S))
+// and the number of padded rows and columns are computed in the same way.
+// When the stride is 1, we have the simplified case
+// R'=R-K+1, C'=C-K+1, Pr=0, Pc=0.
+//
+// For convolution, mathematically, the output value at location (b, r', c', m)
+// is the inner product of two vectors: the chunk of input at
+// (b, (r'*S-Pr) : (r'*S-Pr+K), (c'*S-Pc) : (c'*S-Pc+K), :),
+// and the filter at (m, :, :, :).
+//
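+// A worked example (numbers chosen only for illustration): for a 5x5 input,
+// a 3x3 filter, stride 2 and SAME padding, R' = C' = ceil(5/2) = 3, and the
+// total padding needed per dimension is (3 - 1) * 2 + 3 - 5 = 2, so pad_rows
+// and pad_cols both come back as 1:
+//
+//   int out_h, out_w, pad_r, pad_c;
+//   Status s = Get2dOutputSize(5, 5, 3, 3, 2, 2, Padding::SAME,
+//                              &out_h, &out_w, &pad_r, &pad_c);
+//   // out_h == 3, out_w == 3, pad_r == 1, pad_c == 1
+//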
+Status Get2dOutputSize(const int in_height, const int in_width,
+ int filter_height, int filter_width, int row_stride,
+ int col_stride, Padding padding, int* new_height,
+ int* new_width, int* pad_rows, int* pad_cols);
+
+// Returns the same output dimensions as in Get2dOutputSize, but returns verbose
+// padding dimensions (top/bottom/left/right). Any excess padding (caused by
+// an odd padding size value) is added to the 'pad_bottom' and 'pad_right'
+// dimensions.
+Status Get2dOutputSizeVerbose(const int in_height, const int in_width,
+ int filter_height, int filter_width,
+ int row_stride, int col_stride, Padding padding,
+ int* new_height, int* new_width, int* pad_top,
+ int* pad_bottom, int* pad_left, int* pad_right);
+
+// Calculates broadcast starting index and size. For SAME padding, additional
+// padding may be applied to the right, left, top and bottom. Depending on the
+// current index, input size, kernel size, stride, padding size, the starting
+// index and size for broadcast for that dimension are different from the
+// current index and kernel size.
+// This is mainly used by gradient algorithms for pooling operations.
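+//
+// A small worked example (values chosen only for illustration): with
+// in_size = 3, ksize = 3, stride = 1, pad_size = 1 and index = 0, the window
+// starts inside the padding, so the broadcast is clipped to bindex = 0 and
+// bsize = 2 -- the same case exercised by GetBroadcastTest3_3_1_1 in
+// ops_util_test.cc.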
+Status GetBroadcastSize(const int index, const int in_size,
+ const int ksize, const int stride,
+ const int pad_size, int* bindex, int* bsize);
+
+// Converts Brain's Padding to Eigen's PaddingType.
+Eigen::PaddingType BrainPadding2EigenPadding(Padding padding);
+
+// Given a shape 's' of a tensor of type T, returns true iff the number of
+// bytes occupied by each dim-0 slice (i.e., &tensor(i + 1, ...) -
+// &tensor(i, ...)) is a multiple of EIGEN_MAX_ALIGN_BYTES.
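+//
+// For example (assuming EIGEN_MAX_ALIGN_BYTES is 16, which is not guaranteed
+// here), a float tensor of shape {5, 4} has 4 * sizeof(float) = 16 bytes per
+// dim-0 slice and is aligned, while a float tensor of shape {5, 3} (12 bytes
+// per slice) is not.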
+template <typename T>
+bool IsInnerDimsSizeAligned(const TensorShape& s) {
+ if (s.dims() == 0) return false;
+ const int64 dim0_size = s.dim_size(0);
+ if (dim0_size == 0) return false;
+ const int64 bytes_per_dim0 = (s.num_elements() / dim0_size) * sizeof(T);
+ return bytes_per_dim0 % EIGEN_MAX_ALIGN_BYTES == 0;
+}
+
+// Returns in 'col_data', image patches in storage order (height, width, depth)
+// extracted from image at 'input_data', which is required to be in storage
+// order (batch, height, width, depth).
+// Implementation written by Yangqing Jia (jiayq).
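+//
+// For example (sizes chosen only for illustration), a 3x3 single-channel
+// image with a 2x2 filter, stride 1 and no padding gives height_col =
+// width_col = 2, so 'col_data' must hold 2 * 2 * 2 * 2 * 1 = 16 values: one
+// filter_h x filter_w x depth patch per output location.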
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int height,
+ const int width, const int filter_h, const int filter_w,
+ const int pad_t, const int pad_l, const int pad_b, const int pad_r,
+ const int stride_h, const int stride_w, T* col_data) {
+ int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
+ memcpy(col_data, input_data + (ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+          // This region lies outside the image, so it is simply zero-padded.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+}
+
+// Returns in 'im_data' image patch in storage order (height, width, depth),
+// constructed from patches in 'col_data', which is required to be in storage
+// order (out_height * out_width, filter_height, filter_width, in_depth).
+// Implementation by Yangqing Jia (jiayq).
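+//
+// Col2im is the adjoint of Im2col above: 'im_data' is first zero-filled and
+// each of the height_col * width_col patches in 'col_data' is accumulated
+// back into its source window, so pixels covered by several overlapping
+// patches receive the sum of their contributions.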
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int height,
+ const int width, const int filter_h, const int filter_w,
+ const int pad_t, const int pad_l, const int pad_b, const int pad_r,
+ const int stride_h, const int stride_w, T* im_data) {
+ memset(im_data, 0, sizeof(T) * height * width * depth);
+ int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
+ // TODO(andydavis) Vectorize this loop (if compiler does not).
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+        // Jump to the next row of the image patch, skipping the columns that
+        // lie outside the filter width.
+ im_patch_data += depth * (width - filter_w);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc
new file mode 100644
index 0000000000..bc4f57e220
--- /dev/null
+++ b/tensorflow/core/kernels/ops_util_test.cc
@@ -0,0 +1,265 @@
+#include "tensorflow/core/kernels/ops_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class OpsUtilTest : public ::testing::Test {
+ protected:
+ OpsUtilTest() {}
+ ~OpsUtilTest() override {}
+
+ // Padding structure.
+ struct padding_struct {
+ // Input parameters.
+ struct {
+ int in_height;
+ int in_width;
+ int filter_height;
+ int filter_width;
+ int row_stride;
+ int col_stride;
+ Padding padding;
+ } input;
+ // Output.
+ struct {
+ int new_height;
+ int new_width;
+ int pad_top;
+ int pad_bottom;
+ int pad_left;
+ int pad_right;
+ } output;
+ };
+
+ // Broadcast structure.
+ struct bcast_struct {
+ // Input parameters.
+ struct {
+ int index; // Current index.
+ int in_size; // Size of the dimension.
+ int ksize; // Kernel size.
+ int stride; // Stride.
+ int pad_size; // Padding size.
+ } input;
+ // Output.
+ struct {
+ int new_index; // New starting index.
+ int new_size; // New broadcast size.
+ } output;
+ };
+
+ static void VerifyGet2dOutputSizeBoundaries(padding_struct pad_struct,
+ error::Code code) {
+ int new_height, new_width, pad_rows, pad_cols;
+ Status status = Get2dOutputSize(
+ pad_struct.input.in_height, pad_struct.input.in_width,
+ pad_struct.input.filter_height, pad_struct.input.filter_width,
+ pad_struct.input.row_stride, pad_struct.input.col_stride,
+ pad_struct.input.padding, &new_height, &new_width, &pad_rows,
+ &pad_cols);
+ EXPECT_EQ(status.code(), code) << status;
+ }
+
+ static void VerifyGet2dOutputSizeValues(padding_struct pad_struct,
+ error::Code code) {
+ int new_height, new_width, pad_rows, pad_cols;
+ Status status = Get2dOutputSize(
+ pad_struct.input.in_height, pad_struct.input.in_width,
+ pad_struct.input.filter_height, pad_struct.input.filter_width,
+ pad_struct.input.row_stride, pad_struct.input.col_stride,
+ pad_struct.input.padding, &new_height, &new_width, &pad_rows,
+ &pad_cols);
+ EXPECT_EQ(status.code(), code) << status;
+ EXPECT_EQ(pad_struct.output.new_height, new_height);
+ EXPECT_EQ(pad_struct.output.new_width, new_width);
+ EXPECT_EQ(pad_struct.output.pad_top, pad_rows);
+ EXPECT_EQ(pad_struct.output.pad_left, pad_cols);
+ }
+
+ static void VerifyGet2dOutputVerboseSizeValues(padding_struct pad_struct,
+ error::Code code) {
+ int new_height, new_width, pad_top, pad_bottom, pad_left, pad_right;
+ Status status = Get2dOutputSizeVerbose(
+ pad_struct.input.in_height, pad_struct.input.in_width,
+ pad_struct.input.filter_height, pad_struct.input.filter_width,
+ pad_struct.input.row_stride, pad_struct.input.col_stride,
+ pad_struct.input.padding, &new_height, &new_width, &pad_top,
+ &pad_bottom, &pad_left, &pad_right);
+ EXPECT_EQ(status.code(), code) << status;
+ EXPECT_EQ(pad_struct.output.new_height, new_height);
+ EXPECT_EQ(pad_struct.output.new_width, new_width);
+ EXPECT_EQ(pad_struct.output.pad_top, pad_top);
+ EXPECT_EQ(pad_struct.output.pad_bottom, pad_bottom);
+ EXPECT_EQ(pad_struct.output.pad_left, pad_left);
+ EXPECT_EQ(pad_struct.output.pad_right, pad_right);
+ }
+
+ static void VerifyBoundaries(bcast_struct bcast, error::Code code) {
+ int new_index, new_size;
+ Status status = GetBroadcastSize(
+ bcast.input.index, bcast.input.in_size, bcast.input.ksize,
+ bcast.input.stride, bcast.input.pad_size, &new_index, &new_size);
+ EXPECT_EQ(status.code(), code) << status;
+ }
+
+ static void VerifyBcastValues(bcast_struct bcast) {
+ int new_index, new_size;
+ EXPECT_EQ(Status::OK(),
+ GetBroadcastSize(bcast.input.index, bcast.input.in_size,
+ bcast.input.ksize, bcast.input.stride,
+ bcast.input.pad_size, &new_index, &new_size));
+ EXPECT_EQ(bcast.output.new_index, new_index);
+ EXPECT_EQ(bcast.output.new_size, new_size);
+ }
+};
+
+// Test stride > ksize fails with INVALID_ARGUMENT.
+TEST_F(OpsUtilTest, Get2dOutputSizeInvalidTest) {
+ padding_struct pad_struct = {{3, 3, 1, 2, 2, 2, SAME}, {3, 3, 1, 1, 1, 1}};
+ VerifyGet2dOutputSizeBoundaries(pad_struct, error::INVALID_ARGUMENT);
+}
+
+TEST_F(OpsUtilTest, Get2dOutputSizeNegativeSizeTest) {
+ padding_struct pad_struct = {{1, 1, 3, 3, 1, 1, VALID}, {-1, -1, 0, 0, 0, 0}};
+ VerifyGet2dOutputSizeBoundaries(pad_struct, error::INVALID_ARGUMENT);
+}
+
+TEST_F(OpsUtilTest, Get2dOutputSizeSquareFilterTest) {
+ padding_struct pad_struct1 = {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 0, 0, 0}};
+ padding_struct pad_struct2 = {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}};
+ VerifyGet2dOutputSizeValues(pad_struct1, error::OK);
+ VerifyGet2dOutputSizeValues(pad_struct2, error::OK);
+}
+
+TEST_F(OpsUtilTest, Get2dOutputSizeNonSquareFilterTest) {
+ padding_struct pad_struct1 = {{4, 5, 1, 2, 1, 1, SAME}, {4, 5, 0, 0, 0, 0}};
+ padding_struct pad_struct2 = {{4, 5, 1, 2, 1, 1, VALID}, {4, 4, 0, 0, 0, 0}};
+ VerifyGet2dOutputSizeValues(pad_struct1, error::OK);
+ VerifyGet2dOutputSizeValues(pad_struct2, error::OK);
+}
+
+TEST_F(OpsUtilTest, Get2dOutputSizeUnevenStrideTest) {
+ padding_struct pad_struct1 = {{4, 4, 2, 2, 1, 2, VALID}, {3, 2, 0, 0, 0, 0}};
+ padding_struct pad_struct2 = {{4, 4, 2, 2, 2, 1, VALID}, {2, 3, 0, 0, 0, 0}};
+ VerifyGet2dOutputSizeValues(pad_struct1, error::OK);
+ VerifyGet2dOutputSizeValues(pad_struct2, error::OK);
+}
+
+TEST_F(OpsUtilTest, Get2dOutputSizeVerbose) {
+ padding_struct pad_struct1 = {{3, 3, 2, 2, 2, 2, SAME}, {2, 2, 0, 1, 0, 1}};
+ padding_struct pad_struct2 = {{3, 3, 2, 2, 2, 2, VALID}, {1, 1, 0, 0, 0, 0}};
+ VerifyGet2dOutputVerboseSizeValues(pad_struct1, error::OK);
+ VerifyGet2dOutputVerboseSizeValues(pad_struct2, error::OK);
+}
+
+// Test stride > ksize fails with INVALID_ARGUMENT.
+TEST_F(OpsUtilTest, GetBroadcastTest3_1_2_0) {
+ bcast_struct bcast = {{0, 3, 1, 2, 0}, {0, 3}};
+ VerifyBoundaries(bcast, error::INVALID_ARGUMENT);
+}
+
+// Test index * stride > in_size fails with INVALID_ARGUMENT.
+TEST_F(OpsUtilTest, GetBroadcastTestBadIndex) {
+ bcast_struct bcast = {{2, 3, 1, 2, 0}, {0, 3}};
+ VerifyBoundaries(bcast, error::INVALID_ARGUMENT);
+}
+
+// in_size = 3, ksize = 3, stride = 1, pad_size = 0
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_0) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 1, 0}, {0, 3}},
+ {{1, 3, 3, 1, 0}, {1, 2}},
+ {{2, 3, 3, 1, 0}, {2, 1}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 1, pad_size = 1
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_1) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 1, 1}, {0, 2}},
+ {{1, 3, 3, 1, 1}, {0, 3}},
+ {{2, 3, 3, 1, 1}, {1, 2}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 1, pad_size = 2
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_1_2) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 1, 2}, {0, 1}},
+ {{1, 3, 3, 1, 2}, {0, 2}},
+ {{2, 3, 3, 1, 2}, {0, 3}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 2, pad_size = 0
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_0) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 2, 0}, {0, 3}}, {{1, 3, 3, 2, 0}, {2, 1}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 2, pad_size = 1
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_1) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 2, 1}, {0, 2}}, {{1, 3, 3, 2, 1}, {1, 2}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 2, pad_size = 2
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_2_2) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 2, 2}, {0, 1}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 3, pad_size = 0
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_0) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 3, 0}, {0, 3}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 3, pad_size = 1
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_1) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 3, 1}, {0, 2}}, {{1, 3, 3, 3, 1}, {2, 1}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+// in_size = 3, ksize = 3, stride = 3, pad_size = 2
+TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) {
+ bcast_struct bcast[] = {
+ {{0, 3, 3, 3, 2}, {0, 1}},
+ };
+ for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) {
+ VerifyBcastValues(bcast[i]);
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
new file mode 100644
index 0000000000..cb125ea2fe
--- /dev/null
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -0,0 +1,114 @@
+// See docs in ../ops/array_ops.cc.
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/concat_op.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// --------------------------------------------------------------------------
+template <typename Device, typename T>
+class PackOp : public OpKernel {
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+
+ explicit PackOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* c) override {
+ OpInputList values;
+ OP_REQUIRES_OK(c, c->input_list("values", &values));
+ const int num = values.size();
+
+ // Verify that all input shapes match
+ for (int i = 1; i < num; i++) {
+ OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
+ errors::InvalidArgument(
+ "Shapes of all inputs must match: values[0].shape = ",
+ values[0].shape().ShortDebugString(), " != values[", i,
+ "].shape = ", values[i].shape().ShortDebugString()));
+ }
+
+ TensorShape output_shape(values[0].shape());
+ output_shape.InsertDim(0, num);
+
+ // In the num = 1 case, just reshape the input
+ if (num == 1) {
+ Tensor output;
+ CHECK(output.CopyFrom(values[0], output_shape));
+ c->set_output(0, output);
+ return;
+ }
+
+ // Allocate output
+ Tensor* output;
+ OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
+
+ const int output_size = output->NumElements();
+ if (output_size > 0) {
+ auto output_flat = output->shaped<T, 2>({1, output_size});
+
+ // Except for shapes, pack is a special case of concat, so we reuse the
+ // same computational kernels.
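+      //
+      // For example (shapes chosen only for illustration), packing four
+      // [2, 3] inputs produces a [4, 2, 3] output; each input is viewed as a
+      // 1 x 6 matrix and concatenated into the 1 x 24 output_flat above.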
+ ConstMatrixVector inputs_flat;
+ inputs_flat.reserve(num);
+ for (int i = 0; i < num; ++i) {
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ values[i].shaped<T, 2>({1, values[i].NumElements()})));
+ }
+ if (std::is_same<Device, GPUDevice>::value) {
+ ConcatGPU<T>(c->eigen_gpu_device(), inputs_flat, &output_flat);
+ } else {
+ ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+ }
+ }
+ }
+};
+
+#define REGISTER_PACK(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Pack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ PackOp<CPUDevice, type>)
+
+TF_CALL_ALL_TYPES(REGISTER_PACK);
+REGISTER_PACK(quint8);
+REGISTER_PACK(qint8);
+REGISTER_PACK(qint32);
+REGISTER_PACK(bfloat16);
+
+#undef REGISTER_PACK
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ PackOp<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Pack")
+ .Device(DEVICE_GPU)
+ .HostMemory("values")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ PackOp<CPUDevice, int32>);
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
new file mode 100644
index 0000000000..6c66e54e3d
--- /dev/null
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -0,0 +1,159 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/pad_op.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class PadOp : public OpKernel {
+ public:
+ explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& in0 = context->input(0);
+ const Tensor& in1 = context->input(1);
+ const int dims = in0.dims();
+ static const int kMinDims = 0;
+ static const int kMaxDims = 5;
+ OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
+ errors::Unimplemented("inputs rank not in [", kMinDims, ",",
+ kMaxDims, "]: ", dims));
+ OP_REQUIRES(
+ context,
+ TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
+ errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
+ in1.shape().DebugString()));
+ const int fixed_dims =
+ (kAllowLegacyScalars && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims;
+ OP_REQUIRES(
+ context, fixed_dims == in1.dim_size(0),
+ errors::InvalidArgument(
+ "The first dimension of paddings must be the rank of inputs",
+ in1.shape().DebugString(), " ", in0.shape().DebugString()));
+
+ // Compute the shape of the output tensor, and allocate it.
+ TensorShape output_shape;
+ TTypes<int32>::ConstMatrix paddings = in1.matrix<int32>();
+ for (int d = 0; d < fixed_dims; ++d) {
+ const int32 before_d = paddings(d, 0); // Pad before existing elements.
+      const int32 after_d = paddings(d, 1);   // Pad after existing elements.
+ OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
+ errors::InvalidArgument("Paddings must be non-negative: ",
+ before_d, " ", after_d));
+ const int size_d =
+ (kAllowLegacyScalars && d == in0.dims()) ? 1 : in0.dim_size(d);
+ output_shape.AddDim(before_d + size_d + after_d);
+ }
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+ // Invoke the dims-specific implementation.
+ switch (fixed_dims) {
+ case 0:
+ Operate<0>(context, in0.tensor<T, 0>(), paddings, output);
+ break;
+ case 1:
+ // TODO(irving): Once Pad doesn't need a scalar special case,
+ // change flat to tensor. That is, once !kAllowLegacyScalars.
+ Operate<1>(context, in0.flat<T>(), paddings, output);
+ break;
+ case 2:
+ Operate<2>(context, in0.tensor<T, 2>(), paddings, output);
+ break;
+ case 3:
+ Operate<3>(context, in0.tensor<T, 3>(), paddings, output);
+ break;
+ case 4:
+ Operate<4>(context, in0.tensor<T, 4>(), paddings, output);
+ break;
+ case 5:
+ Operate<5>(context, in0.tensor<T, 5>(), paddings, output);
+ break;
+ default:
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument("Only ranks up to 5 supported: ",
+ in0.shape().DebugString()));
+ }
+ }
+
+ private:
+ template <int Dims>
+ void Operate(OpKernelContext* context,
+ typename TTypes<T, Dims>::ConstTensor input,
+ TTypes<int32>::ConstMatrix paddings, Tensor* output) {
+ CHECK_EQ(Dims, paddings.dimension(0));
+ CHECK_EQ(2, paddings.dimension(1));
+ Eigen::array<std::pair<int32, int32>, Dims> paddings_array;
+ for (int i = 0; i < Dims; ++i) {
+ paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1));
+ }
+ functor::Pad<Device, T, Dims> functor;
+ functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
+ paddings_array);
+ }
+};
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Pad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("paddings"), \
+ PadOp<CPUDevice, type>)
+
+TF_CALL_ALL_TYPES(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void Pad<GPUDevice, T, Dims>::operator()( \
+ const GPUDevice& d, typename TTypes<T, Dims>::Tensor output, \
+ typename TTypes<T, Dims>::ConstTensor input, \
+ Eigen::array<std::pair<int32, int32>, Dims> paddings); \
+ extern template struct Pad<GPUDevice, T, Dims>;
+
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPEC(T, 0); \
+ DECLARE_GPU_SPEC(T, 1); \
+ DECLARE_GPU_SPEC(T, 2); \
+ DECLARE_GPU_SPEC(T, 3); \
+ DECLARE_GPU_SPEC(T, 4); \
+ DECLARE_GPU_SPEC(T, 5);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("Pad") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("paddings"), \
+ PadOp<GPUDevice, T>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+#endif // GOOGLE_CUDA
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/pad_op.h b/tensorflow/core/kernels/pad_op.h
new file mode 100644
index 0000000000..c4f8a4abda
--- /dev/null
+++ b/tensorflow/core/kernels/pad_op.h
@@ -0,0 +1,27 @@
+#ifndef TENSORFLOW_KERNELS_PAD_OP_H_
+#define TENSORFLOW_KERNELS_PAD_OP_H_
+// Functor definition for PadOp, must be compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by PadOp to do the computations.
+template <typename Device, typename T, int Dims>
+struct Pad {
+ // Pad "input" into "output", as specified by "paddings". See pad_op.cc for
+ // details.
+ void operator()(const Device& d, typename TTypes<T, Dims>::Tensor output,
+ typename TTypes<T, Dims>::ConstTensor input,
+ Eigen::array<std::pair<int32, int32>, Dims> paddings) {
+ output.device(d) = input.pad(paddings);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_PAD_OP_H_
diff --git a/tensorflow/core/kernels/pad_op_gpu.cu.cc b/tensorflow/core/kernels/pad_op_gpu.cu.cc
new file mode 100644
index 0000000000..35a03a2cb2
--- /dev/null
+++ b/tensorflow/core/kernels/pad_op_gpu.cu.cc
@@ -0,0 +1,26 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/pad_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Definition of the GPU implementations declared in pad_op.cc.
+#define DEFINE_GPU_SPECS(T) \
+ template struct functor::Pad<GPUDevice, T, 0>; \
+ template struct functor::Pad<GPUDevice, T, 1>; \
+ template struct functor::Pad<GPUDevice, T, 2>; \
+ template struct functor::Pad<GPUDevice, T, 3>; \
+ template struct functor::Pad<GPUDevice, T, 4>; \
+ template struct functor::Pad<GPUDevice, T, 5>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
new file mode 100644
index 0000000000..35e9bd75fa
--- /dev/null
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -0,0 +1,252 @@
+#include "tensorflow/core/kernels/pooling_ops_common.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/public/tensor.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
+#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+PoolParameters::PoolParameters(OpKernelContext* context,
+ const std::vector<int32>& ksize,
+ const std::vector<int32>& stride,
+ Padding padding,
+ const TensorShape& tensor_in_shape) {
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(context, tensor_in_shape.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional"));
+
+ depth = tensor_in_shape.dim_size(3);
+ tensor_in_cols = tensor_in_shape.dim_size(2);
+ tensor_in_rows = tensor_in_shape.dim_size(1);
+ tensor_in_batch = tensor_in_shape.dim_size(0);
+ window_rows = ksize[1];
+ window_cols = ksize[2];
+ depth_window = ksize[3];
+ row_stride = stride[1];
+ col_stride = stride[2];
+ depth_stride = stride[3];
+
+ // We only support 2D pooling across width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
+ errors::Unimplemented(
+ "MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height."));
+
+ if (depth_window == 1) {
+ OP_REQUIRES_OK(context, Get2dOutputSize(
+ tensor_in_rows, tensor_in_cols, window_rows,
+ window_cols, row_stride, col_stride, padding,
+ &out_height, &out_width, &pad_rows, &pad_cols));
+ } else {
+ // Our current version of depthwise max pooling does not support
+ // any padding, and expects the depth_window to equal the
+ // depth_stride (no overlapping).
+ OP_REQUIRES(
+ context, depth % depth_window == 0,
+ errors::Unimplemented("Depthwise max pooling requires the depth "
+ "window to evenly divide the input depth"));
+ OP_REQUIRES(
+ context, depth_stride == depth_window,
+ errors::Unimplemented("Depthwise max pooling requires the depth "
+ "window to equal the depth stride"));
+
+ // The current version of depthwise max is only implemented on CPU.
+ OP_REQUIRES(context,
+ (DeviceType(static_cast<Device*>(context->device())
+ ->attributes()
+ .device_type()) == DeviceType(DEVICE_CPU)),
+ errors::Unimplemented("Depthwise max pooling is currently "
+ "only implemented for CPU devices."));
+
+ pad_depth = 0;
+ out_depth = depth / depth_window;
+ }
+}
+
+TensorShape PoolParameters::forward_output_shape() {
+ if (depth_window == 1) {
+ // Spatial pooling
+ return TensorShape({tensor_in_batch, out_height, out_width, depth});
+ } else {
+ // Depthwise pooling
+ return TensorShape(
+ {tensor_in_batch, tensor_in_rows, tensor_in_cols, out_depth});
+ }
+}
+
+#ifdef GOOGLE_CUDA
+
+namespace {
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+ uint64 size) {
+ perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
+ size * sizeof(T));
+ perftools::gputools::DeviceMemory<T> typed(wrapped);
+ return typed;
+}
+} // namespace
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void TransformDepth<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 4>::ConstTensor in, \
+ const Eigen::DSizes<Eigen::DenseIndex, 4>& shuffle, \
+ typename TTypes<T, 4>::Tensor out); \
+ extern template struct TransformDepth<GPUDevice, T>;
+
+DECLARE_GPU_SPEC(float);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+template <typename T>
+void DnnPoolingGradOp<T>::Compute(
+ OpKernelContext* context,
+ perftools::gputools::dnn::PoolingMode pooling_mode,
+ const std::vector<int32>& size, const std::vector<int32>& stride,
+ Padding padding, const Tensor* tensor_in, const Tensor* tensor_out,
+ const Tensor& out_backprop, const TensorShape& tensor_in_shape) {
+ CHECK((pooling_mode == perftools::gputools::dnn::PoolingMode::kMaximum) ||
+ (tensor_in && tensor_out))
+ << "For MaxPoolGrad, both tensor_in and tensor_out needs to be "
+ "specified";
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, tensor_in_shape, &output));
+
+ PoolParameters params{context, size, stride, padding, tensor_in_shape};
+ if (!context->status().ok()) {
+ return;
+ }
+
+  /// For now, cudnn does not support the NHWC format, so we need to convert
+  /// the tensors to NCHW before calling cudnn. This conversion should be
+  /// removed once cudnn supports NHWC directly.
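+  ///
+  /// For example (an illustrative shape, not one required by this op), a
+  /// [batch, height, width, channels] = [32, 7, 7, 64] NHWC tensor becomes a
+  /// [32, 64, 7, 7] NCHW tensor under the (0, 3, 1, 2) shuffle used below.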
+ Tensor transformed_input;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({tensor_in_shape.dim_size(0),
+ tensor_in_shape.dim_size(3),
+ tensor_in_shape.dim_size(1),
+ tensor_in_shape.dim_size(2)}),
+ &transformed_input));
+ Tensor transformed_input_backprop;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({tensor_in_shape.dim_size(0),
+ tensor_in_shape.dim_size(3),
+ tensor_in_shape.dim_size(1),
+ tensor_in_shape.dim_size(2)}),
+ &transformed_input_backprop));
+ Tensor transformed_output;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({out_backprop.dim_size(0), out_backprop.dim_size(3),
+ out_backprop.dim_size(1), out_backprop.dim_size(2)}),
+ &transformed_output));
+ Tensor transformed_output_backprop;
+ OP_REQUIRES_OK(
+ context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({out_backprop.dim_size(0), out_backprop.dim_size(3),
+ out_backprop.dim_size(1), out_backprop.dim_size(2)}),
+ &transformed_output_backprop));
+
+ auto nhwc_to_nchw = Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2);
+ if (tensor_in) {
+    // For AvgPoolGrad, the original input tensor is not necessary. However,
+    // cudnn still requires it in order to run, even though it does not
+    // affect the results.
+ functor::TransformDepth<GPUDevice, T>()(
+ context->eigen_device<Device>(), tensor_in->tensor<T, 4>(),
+ nhwc_to_nchw, transformed_input.tensor<T, 4>());
+ }
+ if (tensor_out) {
+    // For AvgPoolGrad, the original output tensor is not necessary. However,
+    // cudnn still requires it in order to run, even though it does not
+    // affect the results.
+ functor::TransformDepth<GPUDevice, T>()(
+ context->eigen_device<Device>(), tensor_out->tensor<T, 4>(),
+ nhwc_to_nchw, transformed_output.tensor<T, 4>());
+ }
+ functor::TransformDepth<GPUDevice, T>()(
+ context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
+ nhwc_to_nchw, transformed_output_backprop.tensor<T, 4>());
+
+ /// Get ready to call cudnn
+ perftools::gputools::dnn::PoolingDescriptor pooling_desc;
+ pooling_desc.set_pooling_mode(pooling_mode)
+ .set_window_height(params.window_rows)
+ .set_window_width(params.window_cols)
+ .set_vertical_stride(params.row_stride)
+ .set_horizontal_stride(params.col_stride)
+ .set_vertical_padding(params.pad_rows)
+ .set_horizontal_padding(params.pad_cols);
+
+ perftools::gputools::dnn::BatchDescriptor orig_output_desc;
+ orig_output_desc.set_count(params.tensor_in_batch)
+ .set_height(params.out_height)
+ .set_width(params.out_width)
+ .set_feature_map_count(params.depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+ perftools::gputools::dnn::BatchDescriptor orig_input_desc;
+ orig_input_desc.set_count(params.tensor_in_batch)
+ .set_height(params.tensor_in_rows)
+ .set_width(params.tensor_in_cols)
+ .set_feature_map_count(params.depth)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+
+ auto orig_output_data =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+ auto orig_input_data =
+ AsDeviceMemory(transformed_input.template flat<T>().data(),
+ transformed_input.template flat<T>().size());
+ auto output_backprop =
+ AsDeviceMemory(transformed_output_backprop.template flat<T>().data(),
+ transformed_output_backprop.template flat<T>().size());
+ auto input_backprop =
+ AsDeviceMemory(transformed_input_backprop.template flat<T>().data(),
+ transformed_input_backprop.template flat<T>().size());
+
+ auto* stream = context->op_device_context<GPUDeviceContext>()->stream();
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+
+ bool status =
+ stream->ThenPoolBackward(pooling_desc, orig_input_desc, orig_input_data,
+ orig_output_desc, orig_output_data,
+ output_backprop, &input_backprop)
+ .ok();
+ OP_REQUIRES(context, status,
+ errors::Internal("cudnn PoolBackward launch failed"));
+
+ /// Transform the output data from NCHW back to NHWC
+ auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
+ auto nchw_to_nhwc = Eigen::DSizes<Eigen::DenseIndex, 4>(0, 2, 3, 1);
+ functor::TransformDepth<GPUDevice, T>()(
+ context->eigen_device<Device>(),
+ toConstTensor(transformed_input_backprop).template tensor<T, 4>(),
+ nchw_to_nhwc, output->tensor<T, 4>());
+}
+
+template class DnnPoolingGradOp<float>;
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
new file mode 100644
index 0000000000..5bf44b6e40
--- /dev/null
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -0,0 +1,264 @@
+#ifndef TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
+#define TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/avgpooling_op.h"
+#include "tensorflow/core/kernels/maxpooling_op.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// A helper class to manage sizes and shapes for pooling operations.
+struct PoolParameters {
+ // Updates context->status if there is an invalid input.
+ PoolParameters(OpKernelContext* context, const std::vector<int32>& ksize,
+ const std::vector<int32>& stride, Padding padding,
+ const TensorShape& tensor_in_shape);
+
+ // Returns the shape of the output for "forward" pooling operations.
+ TensorShape forward_output_shape();
+
+ int depth;
+
+ int tensor_in_cols;
+ int tensor_in_rows;
+ int tensor_in_batch;
+
+ int window_rows;
+ int window_cols;
+ int depth_window;
+
+ int row_stride;
+ int col_stride;
+ int depth_stride;
+
+ int out_height;
+ int out_width;
+ int out_depth;
+
+ int pad_rows;
+ int pad_cols;
+ int pad_depth;
+};
+
+// An implementation of MaxPooling (forward).
+template <typename Device, typename T>
+class MaxPoolingOp : public UnaryOp<T> {
+ public:
+ explicit MaxPoolingOp(OpKernelConstruction* context) : UnaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(context, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& tensor_in = context->input(0);
+ PoolParameters params{context, ksize_, stride_, padding_,
+ tensor_in.shape()};
+ if (!context->status().ok()) {
+ return;
+ }
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, params.forward_output_shape(), &output));
+
+ if (params.depth_window > 1) {
+ DepthwiseMaxPool(context, output, tensor_in, params);
+ } else {
+ SpatialMaxPool(context, output, tensor_in, params, padding_);
+ }
+ }
+
+ private:
+  // Single-threaded implementation of DepthwiseMaxPool which does not handle
+  // all of the same options as SpatialMaxPool (it assumes no padding and that
+  // the depth stride equals the depth window).
+ //
+ // TODO(vrv): implement a more general depthwise-max pool that works
+ // on GPU as well.
+ void DepthwiseMaxPool(OpKernelContext* context, Tensor* output,
+ const Tensor& tensor_in, const PoolParameters& params) {
+ Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ in_by_pool(tensor_in.flat<T>().data(), params.depth_window,
+ tensor_in.NumElements() / params.depth_window);
+ Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> out_by_pool(
+ output->flat<T>().data(), 1, output->NumElements());
+ out_by_pool = in_by_pool.colwise().maxCoeff();
+ }
+
+ void SpatialMaxPool(OpKernelContext* context, Tensor* output,
+ const Tensor& tensor_in, const PoolParameters& params,
+ const Padding& padding) {
+ // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an
+ // EigenMatrix version that is currently faster than Eigen's
+ // Spatial MaxPooling implementation.
+ //
+ // TODO(vrv): Remove this once we no longer need it.
+ if (std::is_same<Device, GPUDevice>::value) {
+ Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
+ functor::SpatialMaxPooling<Device, T>()(
+ context->eigen_device<Device>(), output->tensor<T, 4>(),
+ tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
+ params.row_stride, params.col_stride, pt);
+ } else {
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ ConstEigenMatrixMap;
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenMatrixMap;
+
+ ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), params.depth,
+ params.tensor_in_cols * params.tensor_in_rows *
+ params.tensor_in_batch);
+ EigenMatrixMap out_mat(
+ output->flat<T>().data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+
+ // Initializes the output tensor with MIN<T>.
+ output->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
+
+      // The code below does the following:
+ // 1. Flattens the input and output tensors into two dimensional arrays.
+ // tensor_in_as_matrix:
+ // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+ // output_as_matrix:
+ // depth by (out_width * out_height * tensor_in_batch)
+ //
+ // 2. Walks through the set of columns in the flattened
+ // tensor_in_as_matrix,
+ // and updates the corresponding column(s) in output_as_matrix with the
+ // max value.
+ for (int b = 0; b < params.tensor_in_batch; ++b) {
+ for (int h = 0; h < params.tensor_in_rows; ++h) {
+ for (int w = 0; w < params.tensor_in_cols; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ const int hpad = h + params.pad_rows;
+ const int wpad = w + params.pad_cols;
+ const int h_start =
+ (hpad < params.window_rows)
+ ? 0
+ : (hpad - params.window_rows) / params.row_stride + 1;
+ const int h_end =
+ std::min(hpad / params.row_stride + 1, params.out_height);
+ const int w_start =
+ (wpad < params.window_cols)
+ ? 0
+ : (wpad - params.window_cols) / params.col_stride + 1;
+ const int w_end =
+ std::min(wpad / params.col_stride + 1, params.out_width);
+ // compute elementwise max
+ const int in_offset =
+ (b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ const int out_offset =
+ (b * params.out_height + ph) * params.out_width + pw;
+ out_mat.col(out_offset) =
+ out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ std::vector<int32> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+};
+
+template <typename Device, typename T>
+void SpatialAvgPool(OpKernelContext* context, Tensor* output,
+ const Tensor& input, const PoolParameters& params,
+ const Padding& padding) {
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ ConstEigenMatrixMap;
+ typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
+ EigenMatrixMap;
+
+ auto in_flat = input.flat<T>();
+ auto out_flat = output->flat<T>();
+
+ ConstEigenMatrixMap in_mat(
+ in_flat.data(), params.depth,
+ params.tensor_in_cols * params.tensor_in_rows * params.tensor_in_batch);
+ EigenMatrixMap out_mat(
+ out_flat.data(), params.depth,
+ params.out_width * params.out_height * params.tensor_in_batch);
+ Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
+ out_count.setZero();
+
+ // Initializes output to zero.
+ out_flat.setZero();
+
+  // The code below does the following:
+ // 1. Flattens the input and output tensors into two dimensional arrays.
+ // tensor_in_as_matrix:
+ // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch)
+ // output_as_matrix:
+ // depth by (out_width * out_height * tensor_in_batch)
+ //
+ // 2. Walks through the set of columns in the flattened
+ // tensor_in_as_matrix,
+ // and updates the corresponding column(s) in output_as_matrix with the
+ // average value.
+ for (int b = 0; b < params.tensor_in_batch; ++b) {
+ for (int h = 0; h < params.tensor_in_rows; ++h) {
+ for (int w = 0; w < params.tensor_in_cols; ++w) {
+ // (h_start, h_end) * (w_start, w_end) is the range that the input
+ // vector projects to.
+ const int hpad = h + params.pad_rows;
+ const int wpad = w + params.pad_cols;
+ const int h_start =
+ (hpad < params.window_rows)
+ ? 0
+ : (hpad - params.window_rows) / params.row_stride + 1;
+ const int h_end =
+ std::min(hpad / params.row_stride + 1, params.out_height);
+ const int w_start =
+ (wpad < params.window_cols)
+ ? 0
+ : (wpad - params.window_cols) / params.col_stride + 1;
+ const int w_end =
+ std::min(wpad / params.col_stride + 1, params.out_width);
+ const int in_offset =
+ (b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
+ Eigen::DSizes<ptrdiff_t, 2> in_indices(0, in_offset);
+ for (int ph = h_start; ph < h_end; ++ph) {
+ for (int pw = w_start; pw < w_end; ++pw) {
+ const int out_offset =
+ (b * params.out_height + ph) * params.out_width + pw;
+ out_mat.col(out_offset) += in_mat.col(in_offset);
+ out_count(out_offset)++;
+ }
+ }
+ }
+ }
+ }
+ DCHECK_GT(out_count.minCoeff(), 0);
+ out_mat.array().rowwise() /= out_count.transpose().array();
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
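
A minimal standalone C++ sketch of the output-range arithmetic used by SpatialMaxPool and SpatialAvgPool above: for a padded input row hpad, it computes the half-open range of output rows whose pooling windows cover that row. The OutputRange helper and the window/stride/padding values below are illustrative only, not part of the TensorFlow sources.

#include <algorithm>
#include <cstdio>

// Half-open range [start, end) of output positions whose pooling windows
// cover padded input position `hpad`, mirroring the kernels above.
static void OutputRange(int hpad, int window, int stride, int out_size,
                        int* start, int* end) {
  *start = (hpad < window) ? 0 : (hpad - window) / stride + 1;
  *end = std::min(hpad / stride + 1, out_size);
}

int main() {
  // Illustrative configuration: 3-tap window, stride 2, one padded row at
  // the top, output height 4 (consistent with an 8-row input).
  const int window = 3, stride = 2, pad = 1, out_height = 4;
  for (int h = 0; h < 8; ++h) {
    int start, end;
    OutputRange(h + pad, window, stride, out_height, &start, &end);
    std::printf("input row %d -> output rows [%d, %d)\n", h, start, end);
  }
  return 0;
}
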
diff --git a/tensorflow/core/kernels/pooling_ops_common_gpu.h b/tensorflow/core/kernels/pooling_ops_common_gpu.h
new file mode 100644
index 0000000000..87a3ef5186
--- /dev/null
+++ b/tensorflow/core/kernels/pooling_ops_common_gpu.h
@@ -0,0 +1,39 @@
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_
+
+#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/avgpooling_op.h"
+#include "tensorflow/core/kernels/maxpooling_op.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/NeuralNetworks"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+// A helper class that launches the cuDNN pooling backward operations.
+// The original input and output tensors are optional for AvgPoolGrad, but
+// mandatory for MaxPoolGrad.
+template <typename T>
+class DnnPoolingGradOp {
+ public:
+ typedef GPUDevice Device;
+ static void Compute(OpKernelContext* context,
+ perftools::gputools::dnn::PoolingMode pooling_mode,
+ const std::vector<int32>& size,
+ const std::vector<int32>& stride, Padding padding,
+ const Tensor* tensor_in, const Tensor* tensor_out,
+ const Tensor& out_backprop,
+ const TensorShape& tensor_in_shape);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
new file mode 100644
index 0000000000..1b13f68a3a
--- /dev/null
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -0,0 +1,153 @@
+#include "tensorflow/core/kernels/queue_base.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+namespace {
+
+template <DataType DT>
+void HandleSliceToElement(const Tensor& parent, Tensor* element, int index) {
+ typedef typename EnumToDataType<DT>::Type T;
+ auto parent_as_matrix = parent.flat_outer_dims<T>();
+ element->flat<T>() = parent_as_matrix.chip(index, 0);
+}
+
+template <DataType DT>
+void HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
+ typedef typename EnumToDataType<DT>::Type T;
+ auto parent_as_matrix = parent->flat_outer_dims<T>();
+ parent_as_matrix.chip(index, 0) = element.flat<T>();
+}
+
+} // namespace
+
+// static
+Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
+ int index) {
+#define HANDLE_TYPE(DT) \
+ if (parent.dtype() == DT) { \
+ HandleSliceToElement<DT>(parent, element, index); \
+ return Status::OK(); \
+ }
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT8);
+ HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("Unhandled data type: ", parent.dtype());
+}
+
+// static
+Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
+ int index) {
+#define HANDLE_TYPE(DT) \
+ if (element.dtype() == DT) { \
+ HandleElementToSlice<DT>(element, parent, index); \
+ return Status::OK(); \
+ }
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT8);
+ HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("Unhandled data type: ", element.dtype());
+}
+
+QueueBase::QueueBase(const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name)
+ : component_dtypes_(component_dtypes),
+ component_shapes_(component_shapes),
+ name_(name) {}
+
+Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
+ if (tuple.size() != static_cast<size_t>(num_components())) {
+ return errors::InvalidArgument(
+ "Wrong number of components in tuple. Expected ", num_components(),
+ ", got ", tuple.size());
+ }
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (tuple[i].dtype() != component_dtypes_[i]) {
+ return errors::InvalidArgument(
+ "Type mismatch in tuple component ", i, ". Expected ",
+ DataTypeString(component_dtypes_[i]), ", got ",
+ DataTypeString(tuple[i].dtype()));
+ }
+ }
+ return Status::OK();
+}
+
+// static
+string QueueBase::ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
+ string result = "[";
+ bool first = true;
+ for (const TensorShape& shape : shapes) {
+ strings::StrAppend(&result, (first ? "" : ", "), shape.ShortDebugString());
+ first = false;
+ }
+ strings::StrAppend(&result, "]");
+ return result;
+}
+
+Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def,
+ const string& op) const {
+ if (node_def.op() != op) {
+ return errors::InvalidArgument("Shared queue '", name_, "' has type '", op,
+ "' that does not match type of Node '",
+ node_def.name(), "': ", node_def.op());
+ }
+ return Status::OK();
+}
+
+Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def,
+ int32 capacity) const {
+ int32 requested_capacity = -1;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity));
+ if (requested_capacity < 0) requested_capacity = kUnbounded;
+ if (requested_capacity != capacity) {
+ return errors::InvalidArgument("Shared queue '", name_, "' has capacity ",
+ capacity, " but requested capacity was ",
+ requested_capacity);
+ }
+ return Status::OK();
+}
+
+Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const {
+ DataTypeVector requested_dtypes;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node_def, "component_types", &requested_dtypes));
+ if (requested_dtypes != component_dtypes_) {
+ return errors::InvalidArgument("Shared queue '", name_,
+ "' has component types ",
+ DataTypeSliceString(component_dtypes_),
+ " but requested component types were ",
+ DataTypeSliceString(requested_dtypes));
+ }
+ return Status::OK();
+}
+
+Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
+ std::vector<TensorShape> requested_shapes;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
+ if (requested_shapes != component_shapes_) {
+ return errors::InvalidArgument("Shared queue '", name_,
+ "' has component shapes ",
+ ShapeListString(component_shapes_),
+ " but requested component shapes were ",
+ ShapeListString(requested_shapes));
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
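
HandleSliceToElement and HandleElementToSlice above move one batch slice in or out of a tensor by viewing it as a 2-D matrix and using Eigen's chip(). Below is a small standalone sketch of the same operation on raw Eigen tensors, assuming the unsupported Eigen CXX11 Tensor module is available; the 4x3 shape is illustrative only.

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A "parent" batch of 4 elements with 3 values each, i.e. the 2-D view
  // that flat_outer_dims<T>() produces in the code above.
  Eigen::Tensor<float, 2> parent(4, 3);
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 3; ++j) parent(i, j) = 10.0f * i + j;
  }

  // CopySliceToElement direction: pull slice 2 out of the parent.
  Eigen::Tensor<float, 1> element(3);
  element = parent.chip(2, 0);

  // CopyElementToSlice direction: write that element into slice 0.
  parent.chip(0, 0) = element;

  for (int j = 0; j < 3; ++j) std::cout << element(j) << " ";
  std::cout << "\n";  // Prints: 20 21 22
  for (int j = 0; j < 3; ++j) std::cout << parent(0, j) << " ";
  std::cout << "\n";  // Prints: 20 21 22 (the written-back slice)
  return 0;
}
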
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
new file mode 100644
index 0000000000..4897102974
--- /dev/null
+++ b/tensorflow/core/kernels/queue_base.h
@@ -0,0 +1,77 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Functionality common to QueueInterface implementations.
+class QueueBase : public QueueInterface {
+ public:
+ // As a possible value of 'capacity'.
+ static const int32 kUnbounded = INT_MAX;
+
+ // Args:
+ // component_dtypes: The types of each component in a queue-element tuple.
+ // component_shapes: The shapes of each component in a queue-element tuple,
+  //   which must either be empty (if the shapes are not specified) or
+  //   have the same size as component_dtypes.
+ // name: A name to use for the queue.
+ QueueBase(const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name);
+
+ // Implementations of QueueInterface methods --------------------------------
+ const DataTypeVector& component_dtypes() const override {
+ return component_dtypes_;
+ }
+
+ // Other public methods -----------------------------------------------------
+ const std::vector<TensorShape>& component_shapes() const {
+ return component_shapes_;
+ }
+
+ protected:
+ // Returns the number of components in a queue-element tuple.
+ int32 num_components() const { return component_dtypes_.size(); }
+
+ // True if shapes were specified. If so, inputs will be validated
+ // against them, etc.
+ bool specified_shapes() const { return component_shapes_.size() > 0; }
+
+ // Code common to Validate*Tuple().
+ Status ValidateTupleCommon(const Tuple& tuple) const;
+
+ // Copies the index^th slice (in the first dimension) of parent into element.
+ static Status CopySliceToElement(const Tensor& parent, Tensor* element,
+ int index);
+
+ // Copies element into the index^th slice (in the first dimension) of parent.
+ static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
+ int index);
+
+ ~QueueBase() override {}
+
+ // Helpers for implementing MatchesNodeDef().
+ static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
+ Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const;
+ Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const;
+ Status MatchesNodeDefTypes(const NodeDef& node_def) const;
+ Status MatchesNodeDefShapes(const NodeDef& node_def) const;
+
+ const DataTypeVector component_dtypes_;
+ const std::vector<TensorShape> component_shapes_;
+ const string name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
new file mode 100644
index 0000000000..c70dc76777
--- /dev/null
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -0,0 +1,288 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class QueueOpKernel : public AsyncOpKernel {
+ public:
+ explicit QueueOpKernel(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
+ QueueInterface* queue;
+ OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
+ callback);
+ ComputeAsync(ctx, queue, [callback, queue]() {
+ queue->Unref();
+ callback();
+ });
+ }
+
+ protected:
+ virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) = 0;
+};
+
+class QueueAccessOpKernel : public QueueOpKernel {
+ public:
+ explicit QueueAccessOpKernel(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
+ // TODO(keveman): Enable timeout.
+ OP_REQUIRES(context, timeout_ == -1,
+ errors::InvalidArgument("Timeout not supported yet."));
+ }
+
+ protected:
+ int64 timeout_;
+};
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input k: (k-1)th element of the tuple.
+class EnqueueOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ DataTypeVector expected_inputs = {DT_STRING_REF};
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
+ callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
+ queue->TryEnqueue(tuple, ctx, callback);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input k: (k-1)th element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+class EnqueueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ DataTypeVector expected_inputs = {DT_STRING_REF};
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
+                         callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
+ queue->TryEnqueueMany(tuple, ctx, callback);
+ }
+
+ ~EnqueueManyOp() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
+ EnqueueManyOp);
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+class DequeueOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
+ callback);
+
+ queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+ }
+
+ ~DequeueOp() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+class DequeueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(
+ ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueManyOp must request a positive number "
+ "of elements"),
+ callback);
+
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+
+ queue->TryDequeueMany(
+ num_elements, ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components),
+ callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+ }
+
+ ~DequeueManyOp() override {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
+ DequeueManyOp);
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+class QueueCloseOp : public QueueOpKernel {
+ public:
+ explicit QueueCloseOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
+ &cancel_pending_enqueues_));
+ }
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ queue->Close(ctx, cancel_pending_enqueues_, callback);
+ }
+
+ private:
+ bool cancel_pending_enqueues_;
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+class QueueSizeOp : public QueueOpKernel {
+ public:
+ explicit QueueSizeOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override {
+ Tensor* Tqueue_size = nullptr;
+    OP_REQUIRES_OK_ASYNC(
+        ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size), callback);
+ Tqueue_size->flat<int32>().setConstant(queue->size());
+ callback();
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
+
+} // namespace tensorflow
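
QueueOpKernel::ComputeAsync above looks up the queue resource and wraps the done callback so the queue's reference is dropped exactly once when the subclass finishes. A minimal standalone sketch of that wrapping pattern follows; ToyResource, RunWithResource, and Done are hypothetical stand-ins, not TensorFlow types.

#include <cassert>
#include <functional>
#include <iostream>

// Toy stand-in for a ref-counted resource such as QueueInterface.
struct ToyResource {
  int refs = 1;
  void Ref() { ++refs; }
  void Unref() {
    if (--refs == 0) std::cout << "resource destroyed\n";
  }
};

using Done = std::function<void()>;

// Mirrors the wrapping in QueueOpKernel::ComputeAsync: the body receives a
// callback that first releases the reference and then signals completion.
void RunWithResource(ToyResource* resource, Done done,
                     const std::function<void(ToyResource*, Done)>& body) {
  body(resource, [resource, done]() {
    resource->Unref();
    done();
  });
}

int main() {
  ToyResource queue;
  queue.Ref();  // Simulates the lookup handing out an extra reference.
  bool completed = false;
  RunWithResource(&queue, [&completed]() { completed = true; },
                  [](ToyResource* /*queue*/, Done wrapped_done) {
                    // A real body would start asynchronous work; here we
                    // finish immediately and invoke the callback once.
                    wrapped_done();
                  });
  assert(completed);
  assert(queue.refs == 1);
  std::cout << "callback ran, reference released\n";
  return 0;
}
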
diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc
new file mode 100644
index 0000000000..4fc12e92cb
--- /dev/null
+++ b/tensorflow/core/kernels/random_crop_op.cc
@@ -0,0 +1,103 @@
+// See docs in ../ops/image_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+template <typename T>
+class RandomCropOp : public OpKernel {
+ public:
+ explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, generator_.Init(context));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 3,
+ errors::InvalidArgument("input must be 3-dimensional",
+ input.shape().ShortDebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().ShortDebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().ShortDebugString()));
+
+ auto shape_vec = shape_t.vec<int64>();
+ const int32 target_height = shape_vec(0);
+ const int32 target_width = shape_vec(1);
+
+ const int32 height = input.dim_size(0);
+ const int32 width = input.dim_size(1);
+ const int32 channels = input.dim_size(2);
+
+    // Allocate the output with shape {target_height, target_width, channels}.
+ Tensor* output = nullptr;
+ const auto output_shape =
+ TensorShape({target_height, target_width, channels});
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+ // If the target size matches the actual size, then do nothing.
+ if ((target_height == height) && (target_width == width)) {
+ *output = context->input(0);
+ }
+
+ // TODO(shlens): Implement edge case to guarantee output size dimensions.
+    // Edge case. The target dimensions are larger than the image, so
+ // zero-pad the image. This guarantees that the image will *always*
+ // be [target_height, target_width] in size.
+ OP_REQUIRES(context, width >= target_width, errors::FailedPrecondition(
+ "width must be >= target_width: width = ", width,
+ ", target_width = ", target_width));
+ OP_REQUIRES(context, height >= target_height, errors::FailedPrecondition(
+ "height must be >= target_height: height = ", height,
+ ", target_height = ", target_height));
+
+ int32 offset_height = 0;
+ int32 offset_width = 0;
+
+ auto local_gen = generator_.ReserveSamples32(2);
+ random::SimplePhilox random(&local_gen);
+
+ if (width > target_width) {
+ offset_width = random.Rand32() % (width - target_width + 1);
+ }
+ if (height > target_height) {
+ offset_height = random.Rand32() % (height - target_height + 1);
+ }
+
+ // TODO(shlens): Do this more efficiently with memcpy once padding is
+ // available for smaller images.
+ typename TTypes<T, 3>::ConstTensor input_data = input.tensor<T, 3>();
+ typename TTypes<T, 3>::Tensor output_data = output->tensor<T, 3>();
+
+ for (int y = 0; y < target_height; ++y) {
+ for (int x = 0; x < target_width; ++x) {
+ for (int c = 0; c < channels; ++c) {
+ output_data(y, x, c) =
+ input_data(y + offset_height, x + offset_width, c);
+ }
+ }
+ }
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomCrop").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ RandomCropOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
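
RandomCropOp above draws each crop offset uniformly from [0, dim - target], so the crop window always stays inside the image. A short standalone sketch of that offset computation, with std::mt19937 standing in for the guarded Philox generator; sizes and seed are illustrative.

#include <cstdio>
#include <random>

int main() {
  // Illustrative image and crop sizes.
  const int height = 10, width = 8;
  const int target_height = 4, target_width = 8;

  std::mt19937 rng(12345);
  int offset_height = 0, offset_width = 0;
  // Mirrors RandomCropOp: sample only when there is room to move, keeping
  // the whole [target_height, target_width] window inside the image.
  if (height > target_height) {
    offset_height = static_cast<int>(rng() % (height - target_height + 1));
  }
  if (width > target_width) {
    offset_width = static_cast<int>(rng() % (width - target_width + 1));
  }
  std::printf("crop rows [%d, %d], cols [%d, %d]\n", offset_height,
              offset_height + target_height - 1, offset_width,
              offset_width + target_width - 1);
  return 0;
}
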
diff --git a/tensorflow/core/kernels/random_crop_op_test.cc b/tensorflow/core/kernels/random_crop_op_test.cc
new file mode 100644
index 0000000000..1f232f4969
--- /dev/null
+++ b/tensorflow/core/kernels/random_crop_op_test.cc
@@ -0,0 +1,60 @@
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+class RandomCropOpTest : public OpsTestBase {
+ protected:
+ RandomCropOpTest() {
+ RequireDefaultOps();
+ EXPECT_OK(NodeDefBuilder("random_crop_op", "RandomCrop")
+ .Input(FakeInput(DT_UINT8))
+ .Input(FakeInput(DT_INT64))
+ .Attr("T", DT_UINT8)
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(RandomCropOpTest, Basic) {
+ AddInputFromArray<uint8>(TensorShape({1, 2, 1}), {2, 2});
+ AddInputFromArray<int64>(TensorShape({2}), {1, 1});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_UINT8, TensorShape({1, 1, 1}));
+ test::FillValues<uint8>(&expected, {2});
+ test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
+}
+
+TEST_F(RandomCropOpTest, SameSizeOneChannel) {
+ AddInputFromArray<uint8>(TensorShape({2, 1, 1}), {1, 2});
+ AddInputFromArray<int64>(TensorShape({2}), {2, 1});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 1}));
+ test::FillValues<uint8>(&expected, {1, 2});
+ test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
+}
+
+TEST_F(RandomCropOpTest, SameSizeMultiChannel) {
+ AddInputFromArray<uint8>(TensorShape({2, 1, 3}), {1, 2, 3, 4, 5, 6});
+ AddInputFromArray<int64>(TensorShape({2}), {2, 1});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_UINT8, TensorShape({2, 1, 3}));
+ test::FillValues<uint8>(&expected, {1, 2, 3, 4, 5, 6});
+ test::ExpectTensorEqual<uint8>(expected, *GetOutput(0));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
new file mode 100644
index 0000000000..09b66d30e6
--- /dev/null
+++ b/tensorflow/core/kernels/random_op.cc
@@ -0,0 +1,276 @@
+// See docs in ../ops/random_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/random_op.h"
+
+#include <algorithm>
+#include <memory>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+// The default implementation of the functor, which should never be invoked.
+// We still need to provide an implementation for now so that the linker
+// works, since we do not yet support all the distributions.
+template <typename Device, class Distribution>
+struct FillPhiloxRandom {
+ typedef typename Distribution::ResultElementType T;
+ void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
+ T* data, int64 size) {
+ LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
+ }
+};
+
+#if GOOGLE_CUDA
+// Declaration for the partial specialization with GPU
+template <class Distribution>
+struct FillPhiloxRandom<GPUDevice, Distribution> {
+ typedef typename Distribution::ResultElementType T;
+ void operator()(OpKernelContext* ctx, const GPUDevice&,
+ random::PhiloxRandom gen, T* data, int64 size);
+};
+
+#endif
+
+// A class to fill a specified range of random groups
+template <class Distribution, bool VariableSamplesPerOutput>
+struct FillPhiloxRandomTask;
+
+// Specialization for distributions that take a fixed number of samples for
+// each output.
+template <class Distribution>
+struct FillPhiloxRandomTask<Distribution, false> {
+ typedef typename Distribution::ResultElementType T;
+ static void Run(random::PhiloxRandom gen, T* data, int64 size,
+ int64 start_group, int64 limit_group) {
+ Distribution dist;
+ const int kGroupSize = Distribution::kResultElementCount;
+
+ gen.Skip(start_group);
+ int64 offset = start_group * kGroupSize;
+
+ // First fill all the full-size groups
+ int64 limit_group_full = std::min(limit_group, size / kGroupSize);
+ for (int64 index = start_group; index < limit_group_full; ++index) {
+ auto samples = dist(&gen);
+ std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
+ offset += kGroupSize;
+ }
+
+ // If there are any remaining elements that need to be filled, process them
+ if (limit_group_full < limit_group) {
+ int remaining_size = size - limit_group_full * kGroupSize;
+ auto samples = dist(&gen);
+ std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
+ }
+ }
+};
+
+// Specialization for distributions that take a variable number of samples for
+// each output. This will be slower due to the generality.
+template <class Distribution>
+struct FillPhiloxRandomTask<Distribution, true> {
+ typedef typename Distribution::ResultElementType T;
+ static const int64 kReservedSamplesPerOutput = 256;
+
+ static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
+ int64 start_group, int64 limit_group) {
+ using random::PhiloxRandom;
+ using random::SingleSampleAdapter;
+
+ Distribution dist;
+ const int kGroupSize = Distribution::kResultElementCount;
+
+ static const int kGeneratorSkipPerOutputGroup =
+ kGroupSize * kReservedSamplesPerOutput /
+ PhiloxRandom::kResultElementCount;
+
+ int64 offset = start_group * kGroupSize;
+
+ // First fill all the full-size groups
+ int64 limit_group_full = std::min(limit_group, size / kGroupSize);
+ int64 group_index;
+ for (group_index = start_group; group_index < limit_group_full;
+ ++group_index) {
+      // Reset the generator to the beginning of the output group region.
+      // This is necessary if we want the results to be independent of the
+      // order in which the work is done.
+ PhiloxRandom gen = base_gen;
+ gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ auto samples = dist(&single_samples);
+ std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
+ offset += kGroupSize;
+ }
+
+ // If there are any remaining elements that need to be filled, process them
+ if (limit_group_full < limit_group) {
+ PhiloxRandom gen = base_gen;
+ gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ int remaining_size = size - limit_group_full * kGroupSize;
+ auto samples = dist(&single_samples);
+ std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
+ }
+ }
+};
+
+// Partial specialization for CPU to fill the entire region with randoms.
+// It splits the work into several tasks and runs them in parallel.
+template <class Distribution>
+struct FillPhiloxRandom<CPUDevice, Distribution> {
+ typedef typename Distribution::ResultElementType T;
+ void operator()(OpKernelContext* context, const CPUDevice&,
+ random::PhiloxRandom gen, T* data, int64 size) {
+ const int kGroupSize = Distribution::kResultElementCount;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;
+
+    // Limit to a maximum of six threads for now. The performance scaling is
+    // very sub-linear, and too many threads cause much worse overall
+    // performance.
+ int num_workers = 6;
+ Shard(num_workers, worker_threads.workers, total_group_count, kGroupSize,
+ [&gen, data, size](int64 start_group, int64 limit_group) {
+ FillPhiloxRandomTask<
+ Distribution,
+ Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
+ start_group,
+ limit_group);
+ });
+ }
+};
+} // namespace functor
+
+// For now, use the same interface as RandomOp, so we can choose either one
+// at run time.
+template <typename Device, class Distribution>
+class PhiloxRandomOp : public OpKernel {
+ public:
+ typedef typename Distribution::ResultElementType T;
+ explicit PhiloxRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, generator_.Init(ctx));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& input = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsLegacyVector(input.shape()),
+ errors::InvalidArgument("shape must be a vector of {int32,int64}."));
+ Tensor* output = nullptr;
+ if (input.dtype() == DataType::DT_INT32) {
+ auto vec = input.flat<int32>();
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape(
+ vec.data(), vec.size()),
+ &output));
+ } else if (input.dtype() == DataType::DT_INT64) {
+ auto vec = input.flat<int64>();
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShapeUtils::MakeShape(
+ vec.data(), vec.size()),
+ &output));
+ } else {
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(
+ "shape must be a vector of {int32,int64}."));
+ }
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ ctx, ctx->eigen_device<Device>(),
+ ReserveRandomOutputs(output->flat<T>().size()),
+ output->flat<T>().data(), output->flat<T>().size());
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+
+ // Reserve enough random samples in the generator for the given output count.
+ random::PhiloxRandom ReserveRandomOutputs(int64 output_count) {
+ int64 conservative_sample_count = output_count << 8;
+ return generator_.ReserveSamples128(conservative_sample_count);
+ }
+};
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomUniform") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp<CPUDevice, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomStandardNormal") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp<CPUDevice, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TruncatedNormal") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp< \
+ CPUDevice, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
+
+REGISTER(float);
+REGISTER(double);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomUniform") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp<GPUDevice, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomStandardNormal") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp<GPUDevice, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TruncatedNormal") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<TYPE>("dtype"), \
+ PhiloxRandomOp< \
+ GPUDevice, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >)
+
+REGISTER(float);
+REGISTER(double);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // end namespace tensorflow
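
FillPhiloxRandomTask<Distribution, false> above fills its assigned group range by skipping the generator to start_group, writing whole groups, and then handling a possibly partial tail group. The sketch below reproduces that loop with a hypothetical CountingGen that yields four consecutive integers per call; it is illustrative only and does not use the real Philox generator.

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for PhiloxRandom: four values per call, plus Skip(), which is
// all the fill loop below relies on.
struct CountingGen {
  int64_t state = 0;
  void Skip(int64_t groups) { state += 4 * groups; }
  std::array<int64_t, 4> operator()() {
    std::array<int64_t, 4> out{{state, state + 1, state + 2, state + 3}};
    state += 4;
    return out;
  }
};

// Mirrors FillPhiloxRandomTask<..., false>::Run for one [start, limit) range.
void FillRange(CountingGen gen, int64_t* data, int64_t size,
               int64_t start_group, int64_t limit_group) {
  const int kGroupSize = 4;
  gen.Skip(start_group);
  int64_t offset = start_group * kGroupSize;

  // Full-size groups first.
  int64_t limit_group_full = std::min(limit_group, size / kGroupSize);
  for (int64_t g = start_group; g < limit_group_full; ++g) {
    auto samples = gen();
    std::copy(samples.begin(), samples.end(), data + offset);
    offset += kGroupSize;
  }
  // Then the partial tail group, if this range owns it.
  if (limit_group_full < limit_group) {
    int64_t remaining = size - limit_group_full * kGroupSize;
    auto samples = gen();
    std::copy(samples.begin(), samples.begin() + remaining, data + offset);
  }
}

int main() {
  const int64_t size = 10;  // Not a multiple of the group size: has a tail.
  std::vector<int64_t> data(size, -1);
  CountingGen gen;
  // Two hypothetical shards over the 3 groups, e.g. [0, 2) and [2, 3); the
  // real code lets Shard() pick the split.
  FillRange(gen, data.data(), size, 0, 2);
  FillRange(gen, data.data(), size, 2, 3);
  for (int64_t v : data) std::printf("%lld ", static_cast<long long>(v));
  std::printf("\n");  // Prints 0..9: every element written exactly once.
  return 0;
}
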
diff --git a/tensorflow/core/kernels/random_op.h b/tensorflow/core/kernels/random_op.h
new file mode 100644
index 0000000000..7c7eed4227
--- /dev/null
+++ b/tensorflow/core/kernels/random_op.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_KERNELS_RANDOM_OP_H_
+#define TENSORFLOW_KERNELS_RANDOM_OP_H_
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+template <typename Device, class Distribution>
+struct FillPhiloxRandom;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RANDOM_OP_H_
diff --git a/tensorflow/core/kernels/random_op_gpu.cu.cc b/tensorflow/core/kernels/random_op_gpu.cu.cc
new file mode 100644
index 0000000000..15cf85f27e
--- /dev/null
+++ b/tensorflow/core/kernels/random_op_gpu.cu.cc
@@ -0,0 +1,152 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/random_op.h"
+
+#include <stdio.h>
+#include <assert.h>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template <class Distribution, bool VariableSamplesPerOutput>
+struct FillPhiloxRandomKernel;
+
+// A CUDA kernel to fill the data with random numbers from the specified
+// distribution. Each output takes a fixed number of samples.
+template <class Distribution>
+struct FillPhiloxRandomKernel<Distribution, false> {
+ typedef typename Distribution::ResultElementType T;
+ PHILOX_DEVICE_FUNC void Run(random::PhiloxRandom gen, T* data, int64 size) {
+ Distribution dist;
+ const int kGroupSize = Distribution::kResultElementCount;
+
+ const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+ const int32 total_thread_count = gridDim.x * blockDim.x;
+ int32 offset = thread_id * kGroupSize;
+ gen.Skip(thread_id);
+
+ while (offset < size) {
+ typename Distribution::ResultType samples = dist(&gen);
+
+ for (int i = 0; i < kGroupSize; ++i) {
+ if (offset >= size) {
+ return;
+ }
+ data[offset] = samples[i];
+ ++offset;
+ }
+
+ offset += (total_thread_count - 1) * kGroupSize;
+ gen.Skip(total_thread_count - 1);
+ }
+ }
+};
+
+// A CUDA kernel to fill the data with random numbers from the specified
+// distribution. Each output takes a variable number of samples.
+template <class Distribution>
+struct FillPhiloxRandomKernel<Distribution, true> {
+ typedef typename Distribution::ResultElementType T;
+ PHILOX_DEVICE_FUNC void Run(const random::PhiloxRandom& base_gen, T* data,
+ int64 size) {
+ using random::PhiloxRandom;
+ using random::SingleSampleAdapter;
+
+ const int kReservedSamplesPerOutput = 256;
+ const int kGroupSize = Distribution::kResultElementCount;
+ const int kGeneratorSkipPerOutputGroup = kGroupSize *
+ kReservedSamplesPerOutput /
+ PhiloxRandom::kResultElementCount;
+
+ const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+ const int32 total_thread_count = gridDim.x * blockDim.x;
+ int64 group_index = thread_id;
+ int64 offset = group_index * kGroupSize;
+ Distribution dist;
+
+ while (offset < size) {
+      // Since each output takes a variable number of samples, we need to
+      // realign the generator to the beginning of the current output group.
+ PhiloxRandom gen = base_gen;
+ gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ typename Distribution::ResultType samples = dist(&single_samples);
+
+ for (int i = 0; i < kGroupSize; ++i) {
+ if (offset >= size) {
+ return;
+ }
+ data[offset] = samples[i];
+ ++offset;
+ }
+
+ offset += (total_thread_count - 1) * kGroupSize;
+ group_index += total_thread_count;
+ }
+ }
+};
+
+// A simple launch pad to call the correct function templates to fill the data
+template <class Distribution>
+__global__ void __launch_bounds__(1024)
+ FillPhiloxRandomKernelLaunch(random::PhiloxRandom base_gen,
+ typename Distribution::ResultElementType* data,
+ int64 size) {
+ FillPhiloxRandomKernel<Distribution,
+ Distribution::kVariableSamplesPerOutput>()
+ .Run(base_gen, data, size);
+}
+
+// Partial specialization for GPU
+template <class Distribution>
+struct FillPhiloxRandom<GPUDevice, Distribution> {
+ typedef typename Distribution::ResultElementType T;
+ typedef GPUDevice Device;
+ void operator()(OpKernelContext*, const Device& d, random::PhiloxRandom gen,
+ T* data, int64 size) {
+ const int32 block_size = d.maxCudaThreadsPerBlock();
+ const int32 num_blocks =
+ (d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor()) /
+ block_size;
+
+ FillPhiloxRandomKernelLaunch<
+ Distribution><<<num_blocks, block_size, 0, d.stream()>>>(gen, data,
+ size);
+ }
+};
+
+// Explicit instantiation of the GPU distributions functors
+// clang-format off
+// NVCC cannot handle ">>" properly
+template struct FillPhiloxRandom<
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, float> >;
+template struct FillPhiloxRandom<
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, double> >;
+template struct FillPhiloxRandom<
+ GPUDevice, random::NormalDistribution<random::PhiloxRandom, float> >;
+template struct FillPhiloxRandom<
+ GPUDevice, random::NormalDistribution<random::PhiloxRandom, double> >;
+template struct FillPhiloxRandom<
+ GPUDevice, random::TruncatedNormalDistribution<
+ random::SingleSampleAdapter<random::PhiloxRandom>, float> >;
+template struct FillPhiloxRandom<
+ GPUDevice, random::TruncatedNormalDistribution<
+ random::SingleSampleAdapter<random::PhiloxRandom>, double> >;
+// clang-format on
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
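
In the fixed-samples GPU kernel above, thread t writes groups t, t + total_thread_count, t + 2 * total_thread_count, and so on; the per-group writes plus the offset advance of (total_thread_count - 1) * kGroupSize implement that schedule. A small host-side C++ sketch that checks the schedule touches every output index exactly once; the thread count, group size, and buffer size are illustrative.

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // Illustrative sizes; the real kernel derives these from the device.
  const int total_thread_count = 8;
  const int kGroupSize = 4;
  const long long size = 103;  // Deliberately not a multiple of the group size.

  std::vector<int> hits(size, 0);
  for (int thread_id = 0; thread_id < total_thread_count; ++thread_id) {
    long long offset = static_cast<long long>(thread_id) * kGroupSize;
    while (offset < size) {
      // One group of kGroupSize writes, stopping at the end of the buffer.
      for (int i = 0; i < kGroupSize && offset < size; ++i, ++offset) {
        ++hits[offset];
      }
      // Jump past the groups owned by the other threads, as the kernel does
      // with offset += (total_thread_count - 1) * kGroupSize.
      offset += static_cast<long long>(total_thread_count - 1) * kGroupSize;
    }
  }
  for (long long i = 0; i < size; ++i) assert(hits[i] == 1);
  std::printf("all %lld outputs written exactly once\n", size);
  return 0;
}
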
diff --git a/tensorflow/core/kernels/random_op_test.cc b/tensorflow/core/kernels/random_op_test.cc
new file mode 100644
index 0000000000..751b61cfba
--- /dev/null
+++ b/tensorflow/core/kernels/random_op_test.cc
@@ -0,0 +1,99 @@
+#include <random>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+Tensor Int32(int32 v) {
+ Tensor t(DT_INT32, TensorShape({}));
+ t.scalar<int32>()() = v;
+ return t;
+}
+
+Graph* RandomUniform(int64 n) {
+ Graph* g = new Graph(OpRegistry::Global());
+ test::graph::RandomUniform(g, test::graph::Constant(g, Int32(n)), DT_FLOAT);
+ return g;
+}
+
+Graph* RandomNormal(int64 n) {
+ Graph* g = new Graph(OpRegistry::Global());
+ test::graph::RandomGaussian(g, test::graph::Constant(g, Int32(n)), DT_FLOAT);
+ return g;
+}
+
+Graph* RandomParameters(int64 n) {
+ Graph* g = new Graph(OpRegistry::Global());
+ test::graph::RandomParameters(g, test::graph::Constant(g, Int32(n)),
+ DT_FLOAT);
+ return g;
+}
+
+#define BM_RNG(DEVICE, RNG) \
+ static void BM_##DEVICE##_##RNG(int iters, int arg) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * arg); \
+ test::Benchmark(#DEVICE, RNG(arg)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##RNG)->Range(1 << 20, 8 << 20);
+
+BM_RNG(cpu, RandomUniform);
+BM_RNG(cpu, RandomNormal);
+BM_RNG(cpu, RandomParameters);
+
+BM_RNG(gpu, RandomUniform);
+BM_RNG(gpu, RandomNormal);
+BM_RNG(gpu, RandomParameters);
+
+static void BM_PhiloxRandom(int iters) {
+ // Fill 2M random numbers
+ int count = 2 << 20;
+
+ testing::ItemsProcessed(static_cast<int64>(iters) * count);
+
+ random::PhiloxRandom gen(0x12345);
+
+ int val = 1;
+ for (int i = 0; i < iters; ++i) {
+ for (int j = 0; j < count; j += 4) {
+ /// each invocation of gen() returns 128-bit samples
+ auto samples = gen();
+
+ // use the result trivially so the compiler does not optimize it away
+ val ^= samples[0] ^ samples[1] ^ samples[2] ^ samples[3];
+ }
+ }
+
+  // An anchor point to make sure the compiler does not cut corners.
+ CHECK(val) << val;
+}
+BENCHMARK(BM_PhiloxRandom);
+
+static void BM_StdMTRandom(int iters) {
+ // Fill 2M random numbers
+ int count = 2 << 20;
+
+ testing::ItemsProcessed(static_cast<int64>(iters) * count);
+
+ std::mt19937 gen(0x12345);
+
+ int val = 1;
+ for (int i = 0; i < iters; ++i) {
+ for (int j = 0; j < count; ++j) {
+ /// each invocation of gen() returns 32-bit sample
+ uint32 sample = gen();
+
+ // use the result trivially so the compiler does not optimize it away
+ val ^= sample;
+ }
+ }
+
+  // An anchor point to make sure the compiler does not cut corners.
+ CHECK(val) << val;
+}
+BENCHMARK(BM_StdMTRandom);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/random_shuffle_op.cc b/tensorflow/core/kernels/random_shuffle_op.cc
new file mode 100644
index 0000000000..b87f4e58a0
--- /dev/null
+++ b/tensorflow/core/kernels/random_shuffle_op.cc
@@ -0,0 +1,89 @@
+// See docs in ../ops/random_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/guarded_philox_random.h"
+
+namespace tensorflow {
+
+// TODO(irving): If performance is critical, generate output directly instead
+// of an in-place shuffle using a pseudorandom permutation like
+//
+// https://github.com/otherlab/geode/blob/master/geode/random/permute.cpp
+//
+// This is probably also the right thing if we want a GPU version of shuffling.
+
+// We use our own version of std::random_shuffle to guarantee that exactly
+// size - 1 samples are used.
+template <class Iter, class Random>
+static inline void RandomShuffle(Iter first, Iter last, Random& uniform) {
+ if (first == last) return;
+ const auto stop = last - 1;
+ for (auto i = first; i != stop; ++i) {
+ using std::iter_swap;
+ iter_swap(i, i + uniform(last - i));
+ }
+}
+
+template <typename T>
+class RandomShuffleOp : public OpKernel {
+ public:
+ explicit RandomShuffleOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, generator_.Init(context));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+
+ if (input.NumElements() <= 1 || input.dim_size(0) <= 1) {
+ // No shuffling is required, so copy input directly to output
+ context->set_output(0, input);
+ } else {
+ // Reserve enough random samples for shuffling
+ const int64 size = input.dim_size(0);
+ const int64 samples = size - 1;
+ auto local_gen = generator_.ReserveSamples32(samples);
+ random::SingleSampleAdapter<random::PhiloxRandom> single(&local_gen);
+ const auto uniform = [&single](uint32 n) { return single() % n; };
+
+ if (input.dims() == 1) {
+ // For 1D data, copy and then shuffle in place
+ context->set_output(0, tensor::DeepCopy(input));
+ auto vec = context->mutable_output(0)->vec<T>();
+ RandomShuffle(vec.data(), vec.data() + size, uniform);
+ } else {
+ // For >= 2D, shuffle indices and then copy across
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ const auto input_mat = input.flat_outer_dims<T>();
+ auto output_mat = output->flat_outer_dims<T>();
+ std::vector<int> permutation(size);
+ for (int i = 0; i < size; i++) {
+ permutation[i] = i;
+ }
+ RandomShuffle(permutation.begin(), permutation.end(), uniform);
+ for (int i = 0; i < size; i++) {
+ output_mat.template chip<0>(i) =
+ input_mat.template chip<0>(permutation[i]);
+ }
+ }
+ }
+ }
+
+ private:
+ GuardedPhiloxRandom generator_;
+};
+
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("RandomShuffle").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ RandomShuffleOp<T>);
+TF_CALL_ALL_TYPES(REGISTER)
+
+} // namespace tensorflow
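
RandomShuffle above is a Fisher-Yates style shuffle that consumes exactly size - 1 uniform samples, one per swap. Below is a standalone sketch of the same loop over a std::vector, with std::mt19937 reduced modulo the number of remaining positions (as the kernel does) standing in for SingleSampleAdapter<PhiloxRandom>; the data and seed are illustrative.

#include <cstddef>
#include <cstdio>
#include <random>
#include <vector>

// Same structure as the RandomShuffle template above: every position except
// the last is swapped with a uniformly chosen position at or after it.
template <class Iter, class Uniform>
void Shuffle(Iter first, Iter last, Uniform& uniform) {
  if (first == last) return;
  const auto stop = last - 1;
  for (auto i = first; i != stop; ++i) {
    using std::iter_swap;
    iter_swap(i, i + uniform(last - i));
  }
}

int main() {
  std::vector<int> values{0, 1, 2, 3, 4, 5, 6, 7};
  std::mt19937 gen(42);  // Stand-in for SingleSampleAdapter<PhiloxRandom>.
  auto uniform = [&gen](std::ptrdiff_t n) {
    return static_cast<std::ptrdiff_t>(gen() % n);
  };
  Shuffle(values.begin(), values.end(), uniform);
  for (int v : values) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}
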
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
new file mode 100644
index 0000000000..561ec76e53
--- /dev/null
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -0,0 +1,740 @@
+// See docs in ../ops/data_flow_ops.cc.
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class RandomShuffleQueue : public QueueBase {
+ public:
+ RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
+ int64 seed2, const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes,
+ const string& name);
+ Status Initialize(); // Must be called before any other method.
+
+ // Implementations of QueueInterface methods --------------------------------
+
+ Status ValidateTuple(const Tuple& tuple) override;
+ Status ValidateManyTuple(const Tuple& tuple) override;
+ void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) override;
+ void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
+ void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) override;
+ void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) override;
+ Status MatchesNodeDef(const NodeDef& node_def) override;
+
+ int32 size() override {
+ mutex_lock lock(mu_);
+ return queues_[0].size();
+ }
+
+ private:
+ enum Action { kEnqueue, kDequeue };
+
+ ~RandomShuffleQueue() override {}
+
+ TensorShape ManyOutShape(int i, int batch_size) {
+ TensorShape shape({batch_size});
+ shape.AppendShape(component_shapes_[i]);
+ return shape;
+ }
+
+ // Helper for dequeuing a single random element from queues_.
+ void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ void Cancel(Action action, CancellationToken token);
+
+ // Helper for cancelling all pending Enqueue(Many) operations when
+ // Close is called with cancel_pending_enqueues.
+ void CloseAndCancel();
+
+ // Tries to enqueue/dequeue (or close) based on whatever is at the
+ // front of enqueue_attempts_/dequeue_attempts_. Appends to
+  // *clean_up the callback for any finished attempt (so it may be
+ // called once mu_ is released). Returns true if any progress was
+ // made.
+ struct CleanUp {
+ CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+ : finished(f), to_deregister(ct), cm(cm) {}
+ DoneCallback finished;
+ CancellationToken to_deregister;
+ CancellationManager* cm;
+ };
+ bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Tries to make progress on the enqueues or dequeues at the front
+ // of the *_attempts_ queues.
+ void FlushUnlocked();
+
+ const int32 capacity_;
+ const int32 min_after_dequeue_;
+ const int64 original_seed_;
+ const int64 original_seed2_;
+
+ mutex mu_;
+ typedef std::vector<PersistentTensor> SubQueue;
+ std::vector<SubQueue> queues_ GUARDED_BY(mu_);
+ bool closed_ GUARDED_BY(mu_);
+ random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
+ random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
+
+ enum RunResult { kNoProgress, kProgress, kComplete };
+ struct Attempt;
+ typedef std::function<RunResult(Attempt*)> RunCallback;
+ struct Attempt {
+ int32 elements_requested;
+ DoneCallback done_callback; // must be run outside mu_
+ OpKernelContext* context;
+ CancellationToken cancellation_token;
+ RunCallback run_callback; // must be run while holding mu_
+ bool is_cancelled;
+ Tuple tuple;
+
+ Attempt(int32 elements_requested, DoneCallback done_callback,
+ OpKernelContext* context, CancellationToken cancellation_token,
+ RunCallback run_callback)
+ : elements_requested(elements_requested),
+ done_callback(done_callback),
+ context(context),
+ cancellation_token(cancellation_token),
+ run_callback(run_callback),
+ is_cancelled(false) {}
+ };
+ std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+ std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
+};
+
+RandomShuffleQueue::RandomShuffleQueue(
+    int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
+ const DataTypeVector& component_dtypes,
+ const std::vector<TensorShape>& component_shapes, const string& name)
+ : QueueBase(component_dtypes, component_shapes, name),
+ capacity_(capacity),
+ min_after_dequeue_(min_after_dequeue),
+ original_seed_(seed),
+ original_seed2_(seed2),
+ closed_(false),
+ generator_(&parent_generator_) {
+ if (seed == 0 && seed2 == 0) {
+ // If both seeds are unspecified, use completely random seeds.
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+ parent_generator_ = random::PhiloxRandom(seed, seed2);
+}
+
+Status RandomShuffleQueue::Initialize() {
+ if (component_dtypes_.empty()) {
+ return errors::InvalidArgument("Empty component types for queue ", name_);
+ }
+ if (!component_shapes_.empty() &&
+ component_dtypes_.size() != component_shapes_.size()) {
+ return errors::InvalidArgument("Different number of component types (",
+ component_dtypes_.size(), ") vs. shapes (",
+ component_shapes_.size(), ").");
+ }
+
+ mutex_lock lock(mu_);
+ queues_.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ queues_.push_back(SubQueue());
+ queues_.back().reserve(min_after_dequeue_);
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status RandomShuffleQueue::ValidateTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ component_shapes_[i].ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status RandomShuffleQueue::ValidateManyTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ // Expected shape is [batch_size] + component_shapes_[i]
+ const TensorShape expected_shape = ManyOutShape(i, batch_size);
+ if (!tuple[i].shape().IsSameSize(expected_shape)) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ expected_shape.ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ } else {
+ for (size_t i = 1; i < tuple.size(); ++i) {
+ if (tuple[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All input tensors must have the same size in the 0th ",
+ "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+ ", and should have ", batch_size);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
+ DCHECK_GT(queues_[0].size(), 0);
+ int64 index = generator_() % queues_[0].size();
+ (*tuple).reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ (*tuple).push_back(*queues_[i][index].AccessTensor(ctx));
+    // Remove the sampled element in O(1) by overwriting it with the last
+    // element and shrinking the sub-queue; order does not matter because
+    // elements are drawn at random.
+    queues_[i][index] = queues_[i].back();
+    queues_[i].pop_back();
+ }
+}
+
+void RandomShuffleQueue::Cancel(Action action, CancellationToken token) {
+ DoneCallback callback = nullptr;
+ {
+ mutex_lock lock(mu_);
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ for (Attempt& attempt : *attempts) {
+ if (attempt.cancellation_token == token) {
+ attempt.is_cancelled = true;
+ if (action == kEnqueue) {
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ } else {
+ attempt.context->SetStatus(
+ errors::Cancelled("Dequeue operation was cancelled"));
+ }
+ std::swap(callback, attempt.done_callback);
+ break;
+ }
+ }
+ }
+ if (callback) {
+ callback();
+ FlushUnlocked();
+ }
+}
+
+void RandomShuffleQueue::CloseAndCancel() {
+ std::vector<DoneCallback> callbacks;
+ {
+ mutex_lock lock(mu_);
+ closed_ = true;
+ for (Attempt& attempt : enqueue_attempts_) {
+ attempt.is_cancelled = true;
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ callbacks.emplace_back(std::move(attempt.done_callback));
+ }
+ }
+ for (const DoneCallback& callback : callbacks) {
+ callback();
+ }
+ FlushUnlocked();
+}
+
+bool RandomShuffleQueue::TryAttemptLocked(
+ Action action, std::vector<CleanUp>* clean_up) {
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ bool progress = false;
+ bool done = false;
+ while (!done && !attempts->empty()) {
+ if (attempts->front().is_cancelled) {
+ if (action == kEnqueue) {
+ LOG(INFO) << "Skipping cancelled enqueue attempt";
+ } else {
+ LOG(INFO) << "Skipping cancelled dequeue attempt";
+ }
+ attempts->pop_front();
+ } else {
+ Attempt* cur_attempt = &attempts->front();
+ switch (cur_attempt->run_callback(cur_attempt)) {
+ case kNoProgress:
+ done = true;
+ break;
+ case kProgress:
+ done = true;
+ progress = true;
+ break;
+ case kComplete:
+ progress = true;
+ clean_up->emplace_back(std::move(cur_attempt->done_callback),
+ cur_attempt->cancellation_token,
+ cur_attempt->context->cancellation_manager());
+ attempts->pop_front();
+ break;
+ }
+ }
+ }
+ return progress;
+}
+
+void RandomShuffleQueue::FlushUnlocked() {
+ std::vector<CleanUp> clean_up;
+ Ref();
+ {
+ mutex_lock lock(mu_);
+ bool changed;
+ do {
+ changed = TryAttemptLocked(kEnqueue, &clean_up);
+ changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+ } while (changed);
+ }
+ Unref();
+ for (const auto& to_clean : clean_up) {
+ if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+ // NOTE(mrry): We can safely ignore the return value of
+ // DeregisterCallback because the mutex mu_ ensures that the
+ // cleanup action only executes once.
+ to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+ }
+ to_clean.finished();
+ }
+}
+
+void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
+ DoneCallback callback) {
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kEnqueue, token); });
+ if (!already_cancelled) {
+ enqueue_attempts_.emplace_back(
+ 1, callback, ctx, token,
+ [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "RandomShuffleQueue '", name_, "' is closed."));
+ return kComplete;
+ }
+ if (queues_[0].size() < static_cast<size_t>(capacity_)) {
+ for (int i = 0; i < num_components(); ++i) {
+ queues_[i].push_back(PersistentTensor(tuple[i]));
+ }
+ return kComplete;
+ } else {
+ return kNoProgress;
+ }
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
+ callback();
+ }
+}
+
+void RandomShuffleQueue::TryEnqueueMany(const Tuple& tuple,
+ OpKernelContext* ctx,
+ DoneCallback callback) {
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (batch_size == 0) {
+ callback();
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kEnqueue, token); });
+ if (!already_cancelled) {
+ enqueue_attempts_.emplace_back(
+ batch_size, callback, ctx, token,
+ [tuple, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "RandomShuffleQueue '", name_, "' is closed."));
+ return kComplete;
+ }
+ RunResult result = kNoProgress;
+ while (queues_[0].size() < static_cast<size_t>(capacity_)) {
+ result = kProgress;
+ const int index =
+ tuple[0].dim_size(0) - attempt->elements_requested;
+ for (int i = 0; i < num_components(); ++i) {
+ TensorShape element_shape(tuple[i].shape());
+ element_shape.RemoveDim(0);
+ PersistentTensor element;
+ Tensor* element_access = nullptr;
+ attempt->context->allocate_persistent(
+ tuple[i].dtype(), element_shape, &element, &element_access);
+ attempt->context->SetStatus(
+ CopySliceToElement(tuple[i], element_access, index));
+ if (!attempt->context->status().ok()) return kComplete;
+ queues_[i].push_back(element);
+ }
+ --attempt->elements_requested;
+ if (attempt->elements_requested == 0) {
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
+ callback();
+ }
+}
+
+void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
+ CallbackWithTuple callback) {
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ 1, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int32 s = queues_[0].size();
+ if (closed_ && s == 0) {
+ attempt->context->SetStatus(errors::OutOfRange(
+ "RandomShuffleQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ", 1, ", current size ", s,
+ ")"));
+ return kComplete;
+ }
+            // While the queue is still open, keep at least min_after_dequeue_
+            // elements in it so that later dequeues remain well mixed.
+            if (!closed_) s -= min_after_dequeue_;
+ if (s > 0) {
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
+ attempt->done_callback = [callback, tuple]() { callback(tuple); };
+ return kComplete;
+ } else {
+ return kNoProgress;
+ }
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
+
+void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ CallbackWithTuple callback) {
+ if (!specified_shapes()) {
+ ctx->SetStatus(
+ errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the "
+ "components to have specified shapes."));
+ callback(Tuple());
+ return;
+ }
+ if (num_elements == 0) {
+ Tuple tuple;
+ tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ // TODO(josh11b,misard): Switch to allocate_output(). Problem is
+ // this breaks the abstraction boundary since we don't *really*
+ // know if and how the Tensors in the tuple we pass to callback
+ // correspond to the outputs of *ctx. For example, the
+ // ReaderRead Op uses TryDequeue() to get a filename out of a
+ // queue that is used internally by the reader and is not
+ // associated with any output of the ReaderRead.
+ // mrry@ adds:
+ // Maybe we need to pass a std::function<Tensor*(...)> (or
+ // better signature) that calls the appropriate allocator
+ // function in addition to ctx? (Or support a shim Allocator
+ // that has an internal OpKernelContext*, and dispatches to the
+ // appropriate method?)
+ // misard@ adds:
+ // I don't see that a std::function would help. The problem is
+ // that at this point (allocation time) the system doesn't know
+ // what is going to happen to the element read out of the
+ // queue. As long as we keep the generality that TensorFlow Ops
+ // do their own dynamic allocation in arbitrary C++ code, we
+ // need to preserve robustness to allocating output Tensors with
+ // the 'wrong' attributes, and fixing up with a copy. The only
+ // improvement I can see here in the future would be to support
+ // an optimized case where the queue 'knows' what attributes to
+ // use, and plumbs them through here.
+ Tensor element;
+ ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), &element);
+ tuple.emplace_back(element);
+ }
+ callback(tuple);
+ return;
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken token = cm->get_cancellation_token();
+ bool already_cancelled;
+ {
+ mutex_lock l(mu_);
+ already_cancelled = !cm->RegisterCallback(
+ token, [this, token]() { Cancel(kDequeue, token); });
+ if (!already_cancelled) {
+ // TODO(josh11b): This makes two copies of callback, avoid this if possible.
+ dequeue_attempts_.emplace_back(
+ num_elements, [callback]() { callback(Tuple()); }, ctx, token,
+ [callback, this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int32 s = queues_[0].size();
+ if (closed_ && s < attempt->elements_requested) {
+ attempt->context->SetStatus(errors::OutOfRange(
+                "RandomShuffleQueue '", name_, "' is closed and has ",
+ "insufficient elements (requested ",
+ attempt->elements_requested, ", current size ", s, ")"));
+ return kComplete;
+ }
+
+ RunResult result = kNoProgress;
+ if (!closed_) s -= min_after_dequeue_;
+ for (; s > 0; --s) {
+ if (attempt->tuple.empty()) {
+                // Only allocate the tuple when we have something to dequeue
+                // so we don't use excessive memory when there are many
+                // blocked dequeue attempts waiting.
+ attempt->tuple.reserve(num_components());
+ for (int i = 0; i < num_components(); ++i) {
+ const TensorShape shape =
+ ManyOutShape(i, attempt->elements_requested);
+ Tensor element;
+ attempt->context->allocate_temp(component_dtypes_[i], shape,
+ &element);
+ attempt->tuple.emplace_back(element);
+ }
+ }
+ result = kProgress;
+ Tuple tuple;
+ DequeueLocked(attempt->context, &tuple);
+ const int index =
+ attempt->tuple[0].dim_size(0) - attempt->elements_requested;
+ for (int i = 0; i < num_components(); ++i) {
+ attempt->context->SetStatus(
+ CopyElementToSlice(tuple[i], &attempt->tuple[i], index));
+ if (!attempt->context->status().ok()) return kComplete;
+ }
+ tuple.clear();
+ --attempt->elements_requested;
+ if (attempt->elements_requested == 0) {
+ tuple = attempt->tuple;
+ attempt->done_callback = [callback, tuple]() {
+ callback(tuple);
+ };
+ return kComplete;
+ }
+ }
+ return result;
+ });
+ }
+ }
+ if (!already_cancelled) {
+ FlushUnlocked();
+ } else {
+ ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
+ callback(Tuple());
+ }
+}
+
+void RandomShuffleQueue::Close(OpKernelContext* ctx,
+ bool cancel_pending_enqueues,
+ DoneCallback callback) {
+ if (cancel_pending_enqueues) {
+ CloseAndCancel();
+ callback();
+ } else {
+ {
+ mutex_lock lock(mu_);
+ enqueue_attempts_.emplace_back(
+ 0, callback, ctx, CancellationManager::kInvalidToken,
+ [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(errors::Aborted(
+ "RandomShuffleQueue '", name_, "' is already closed."));
+ } else {
+ closed_ = true;
+ }
+ return kComplete;
+ });
+ }
+ FlushUnlocked();
+ }
+}
+
+Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
+ TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue"));
+ TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
+
+ int32 min_after_dequeue = -1;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node_def, "min_after_dequeue", &min_after_dequeue));
+ if (min_after_dequeue != min_after_dequeue_) {
+ return errors::InvalidArgument(
+ "Shared queue '", name_, "' has min_after_dequeue ",
+ min_after_dequeue_, " but requested min_after_dequeue was ",
+ min_after_dequeue, ".");
+ }
+
+ int64 seed = -1;
+ int64 seed2 = -1;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed", &seed));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "seed2", &seed2));
+ if ((seed != 0 || seed2 != 0) &&
+ (seed != original_seed_ || seed2 != original_seed2_)) {
+ return errors::InvalidArgument(
+ "Shared queue '", name_, "' has random seeds (", original_seed_, ", ",
+ original_seed2_, ") but requested seeds are (", seed, ", ", seed2,
+ ").");
+ }
+
+ TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
+ TF_RETURN_IF_ERROR(MatchesNodeDefShapes(node_def));
+
+ return Status::OK();
+}
+
+typedef std::shared_ptr<QueueInterface> QueueInterfacePtr;
+
+// Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
+// backed by RandomShuffleQueue) that persists across different graph
+// executions and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+class RandomShuffleQueueOp : public OpKernel {
+ public:
+ explicit RandomShuffleQueueOp(OpKernelConstruction* context)
+ : OpKernel(context), queue_handle_set_(false) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &queue_handle_, nullptr));
+ if (capacity_ < 0) {
+ capacity_ = RandomShuffleQueue::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("min_after_dequeue", &min_after_dequeue_));
+ OP_REQUIRES(context, min_after_dequeue_ >= 0,
+ errors::InvalidArgument("min_after_dequeue ",
+ min_after_dequeue_, " must be >= 0"));
+ OP_REQUIRES(
+ context, min_after_dequeue_ < capacity_,
+ errors::InvalidArgument("min_after_dequeue ", min_after_dequeue_,
+ " must be < capacity ", capacity_));
+ OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
+ OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
+
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+ }
+
+ ~RandomShuffleQueueOp() override {
+ // If the queue object was not shared, delete it.
+ if (queue_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
+ TF_CHECK_OK(cinfo_.resource_manager()->Delete<QueueInterface>(
+ cinfo_.container(), cinfo_.name()));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(mu_);
+ if (!queue_handle_set_) {
+ OP_REQUIRES_OK(ctx, SetQueueHandle(ctx));
+ }
+ ctx->set_output_ref(0, &mu_, queue_handle_.AccessTensor(ctx));
+ }
+
+ private:
+ Status SetQueueHandle(OpKernelContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def()));
+ QueueInterface* queue;
+ auto creator = [this](QueueInterface** ret) {
+ auto* q = new RandomShuffleQueue(capacity_, min_after_dequeue_, seed_,
+ seed2_, component_types_,
+ component_shapes_, cinfo_.name());
+ Status s = q->Initialize();
+ if (s.ok()) {
+ *ret = q;
+ } else {
+ q->Unref();
+ }
+ return s;
+ };
+ TF_RETURN_IF_ERROR(
+ cinfo_.resource_manager()->LookupOrCreate<QueueInterface>(
+ cinfo_.container(), cinfo_.name(), &queue, creator));
+ core::ScopedUnref unref_me(queue);
+ // Verify that the shared queue is compatible with the requested arguments.
+ TF_RETURN_IF_ERROR(queue->MatchesNodeDef(def()));
+ auto h = queue_handle_.AccessTensor(ctx)->flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ queue_handle_set_ = true;
+ return Status::OK();
+ }
+
+ int32 capacity_;
+ int32 min_after_dequeue_;
+ int64 seed_;
+ int64 seed2_;
+ DataTypeVector component_types_;
+ std::vector<TensorShape> component_shapes_;
+ ContainerInfo cinfo_;
+
+ mutex mu_;
+ PersistentTensor queue_handle_ GUARDED_BY(mu_);
+ bool queue_handle_set_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueueOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueue").Device(DEVICE_CPU),
+ RandomShuffleQueueOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc
new file mode 100644
index 0000000000..a3f4e0b0cb
--- /dev/null
+++ b/tensorflow/core/kernels/range_sampler.cc
@@ -0,0 +1,305 @@
+#include "tensorflow/core/kernels/range_sampler.h"
+
+#include <vector>
+#include <unordered_set>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+using gtl::ArraySlice;
+using gtl::MutableArraySlice;
+
+RangeSampler::~RangeSampler() {}
+
+void RangeSampler::SampleBatch(random::SimplePhilox* rnd, bool unique,
+ gtl::MutableArraySlice<int64> batch) const {
+ SampleBatchGetExpectedCount(
+ rnd, unique, batch, gtl::MutableArraySlice<float>(),
+ gtl::ArraySlice<int64>(), gtl::MutableArraySlice<float>());
+}
+
+void RangeSampler::SampleBatchGetExpectedCount(
+ random::SimplePhilox* rnd, bool unique, gtl::MutableArraySlice<int64> batch,
+ gtl::MutableArraySlice<float> batch_expected_count,
+ gtl::ArraySlice<int64> extras,
+ gtl::MutableArraySlice<float> extras_expected_count) const {
+ SampleBatchGetExpectedCountAvoid(rnd, unique, batch, batch_expected_count,
+ extras, extras_expected_count,
+ gtl::ArraySlice<int64>());
+}
+
+namespace {
+
+// Approximates the expected count of a value in the output of SampleBatch.
+//
+// If unique=false, then this is (Probability(value) * batch_size)
+//
+// We use batch_size and num_tries, where num_tries is the observed number of
+// tries it took to get batch_size unique values.
+//
+// Assuming (falsely) that the number of tries to get a batch of batch_size
+// distinct values is _always_ num_tries, the probability that the value
+// is in a batch is (1 - (1-p)^num_tries)
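+//
+// For example, with p = 0.01 and num_tries = 100 the expected count is
+//   1 - (1 - 0.01)^100 = 1 - 0.99^100 ~= 1 - 0.366 = 0.634.
+// The expm1/log1p form below computes the same quantity while avoiding the
+// cancellation that evaluating 1 - pow(1 - p, num_tries) directly would
+// suffer when p is very small.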
+static float ExpectedCountHelper(float p, int batch_size, int num_tries) {
+ if (num_tries == batch_size) {
+ // This shortcut will always be taken if unique=false
+ return p * batch_size;
+ }
+ // numerically stable version of (1 - (1-p)^num_tries)
+ return -expm1(num_tries * log1p(-p));
+}
+
+} // namespace
+
+void RangeSampler::SampleBatchGetExpectedCountAvoid(
+ random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
+ MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
+ MutableArraySlice<float> extras_expected_count,
+ ArraySlice<int64> avoided_values) const {
+ const int batch_size = batch.size();
+ int num_tries;
+
+ if (unique) {
+ CHECK_LE(batch_size + avoided_values.size(), range_);
+ std::unordered_set<int64> used(batch_size);
+ used.insert(avoided_values.begin(), avoided_values.end());
+ int num_picked = 0;
+ num_tries = 0;
+ while (num_picked < batch_size) {
+ num_tries++;
+ CHECK_LT(num_tries, kint32max);
+ int64 value = Sample(rnd);
+ if (gtl::InsertIfNotPresent(&used, value)) {
+ batch[num_picked++] = value;
+ }
+ }
+ } else {
+ CHECK_EQ(avoided_values.size(), 0)
+ << "avoided_values only supported with unique=true";
+ for (int i = 0; i < batch_size; i++) {
+ batch[i] = Sample(rnd);
+ }
+ num_tries = batch_size;
+ }
+ // Compute the expected counts of the batch and the extra values
+ if (batch_expected_count.size() > 0) {
+ CHECK_EQ(batch_size, batch_expected_count.size());
+ for (int i = 0; i < batch_size; i++) {
+ batch_expected_count[i] =
+ ExpectedCountHelper(Probability(batch[i]), batch_size, num_tries);
+ }
+ }
+ CHECK_EQ(extras.size(), extras_expected_count.size());
+ for (size_t i = 0; i < extras.size(); i++) {
+ extras_expected_count[i] =
+ ExpectedCountHelper(Probability(extras[i]), batch_size, num_tries);
+ }
+}
+
+AllSampler::AllSampler(int64 range)
+ : RangeSampler(range), inv_range_(1.0 / range) {}
+
+void AllSampler::SampleBatchGetExpectedCountAvoid(
+ random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
+ MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
+ MutableArraySlice<float> extras_expected_count,
+ ArraySlice<int64> avoided_values) const {
+ const int batch_size = batch.size();
+ CHECK_EQ(range_, batch_size);
+ for (int i = 0; i < batch_size; i++) {
+ batch[i] = i;
+ }
+ if (batch_expected_count.size() > 0) {
+ CHECK_EQ(batch_size, batch_expected_count.size());
+ for (int i = 0; i < batch_size; i++) {
+ batch_expected_count[i] = 1;
+ }
+ }
+ CHECK_EQ(0, avoided_values.size());
+ CHECK_EQ(extras.size(), extras_expected_count.size());
+ for (size_t i = 0; i < extras.size(); i++) {
+ extras_expected_count[i] = 1;
+ }
+}
+
+UniformSampler::UniformSampler(int64 range)
+ : RangeSampler(range), inv_range_(1.0 / range) {}
+
+int64 UniformSampler::Sample(random::SimplePhilox* rnd) const {
+ return rnd->Uniform64(range_);
+}
+
+float UniformSampler::Probability(int64 value) const { return inv_range_; }
+
+LogUniformSampler::LogUniformSampler(int64 range)
+ : RangeSampler(range), log_range_(log(range + 1)) {}
+
+int64 LogUniformSampler::Sample(random::SimplePhilox* rnd) const {
+ const int64 value =
+ static_cast<int64>(exp(rnd->RandDouble() * log_range_)) - 1;
+ CHECK_GE(value, 0);
+  // Mathematically, value is always in [0, range_), but floating point
+  // roundoff in exp() can occasionally push it to range_, so we mod by
+  // range_ to be safe.
+ return value % range_;
+}
+
+float LogUniformSampler::Probability(int64 value) const {
+  // value is returned iff rnd->RandDouble() * log_range_ in the
+  // Sample() function falls between log(value + 1)
+ // and log(value + 2). The probability of this is:
+ // (log(value + 2) - log(value + 1)) / log_range
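+  // As a sanity check, summing this over all values in [0, range)
+  // telescopes to log(range + 1) / log_range_ = 1, so the probabilities
+  // sum to one (LogUniformChecksum in range_sampler_test.cc checks this
+  // numerically).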
+ // To avoid two calls to log(), we compute this as follows:
+ return (log((value + 2.0) / (value + 1.0))) / log_range_;
+}
+
+ThreadUnsafeUnigramSampler::ThreadUnsafeUnigramSampler(int64 range)
+ : RangeSampler(range), picker_(range) {
+ CHECK_LT(range, kint32max);
+}
+
+int64 ThreadUnsafeUnigramSampler::Sample(random::SimplePhilox* rnd) const {
+ return picker_.Pick(rnd);
+}
+
+float ThreadUnsafeUnigramSampler::Probability(int64 value) const {
+ return static_cast<float>(picker_.get_weight(value)) / picker_.total_weight();
+}
+
+void ThreadUnsafeUnigramSampler::Update(ArraySlice<int64> values) {
+ int num_updates = std::min(static_cast<int>(values.size()),
+ kint32max - picker_.total_weight());
+ for (int i = 0; i < num_updates; i++) {
+ const int64 value = values[i];
+ picker_.set_weight(value, picker_.get_weight(value) + 1);
+ }
+}
+
+// Thread-safe unigram sampler
+UnigramSampler::UnigramSampler(int64 range)
+ : RangeSampler(range), unsafe_sampler_(range) {
+ CHECK_LT(range, kint32max);
+}
+
+int64 UnigramSampler::Sample(random::SimplePhilox* rnd) const {
+ mutex_lock lock(mu_); // could use reader lock
+ return unsafe_sampler_.Sample(rnd);
+}
+
+float UnigramSampler::Probability(int64 value) const {
+ mutex_lock lock(mu_); // could use reader lock
+ return unsafe_sampler_.Probability(value);
+}
+
+// Overriding at a high level results in far fewer lock acquisitions.
+void UnigramSampler::SampleBatchGetExpectedCountAvoid(
+ random::SimplePhilox* rnd, bool unique, MutableArraySlice<int64> batch,
+ MutableArraySlice<float> batch_expected_count, ArraySlice<int64> extras,
+ MutableArraySlice<float> extras_expected_count,
+ ArraySlice<int64> avoided_values) const {
+ mutex_lock lock(mu_); // could use reader lock
+ unsafe_sampler_.SampleBatchGetExpectedCountAvoid(
+ rnd, unique, batch, batch_expected_count, extras, extras_expected_count,
+ avoided_values);
+}
+
+void UnigramSampler::Update(ArraySlice<int64> values) {
+ mutex_lock lock(mu_);
+ unsafe_sampler_.Update(values);
+}
+
+FixedUnigramSampler::FixedUnigramSampler(Env* env, int64 range,
+ const string& vocab_file,
+ float distortion,
+ int32 num_reserved_ids,
+ int32 num_shards, int32 shard)
+ : RangeSampler(range),
+ total_weight_(0.0),
+ num_shards_(num_shards),
+ shard_(shard) {
+ FillReservedIds(num_reserved_ids);
+ // TODO(vanhoucke): make this non-crashing.
+ TF_CHECK_OK(LoadFromFile(env, vocab_file, distortion));
+ CHECK_EQ(range, weights_.size());
+ dist_sampler_.reset(new random::DistributionSampler(weights_));
+}
+
+FixedUnigramSampler::FixedUnigramSampler(int64 range,
+ const std::vector<float>& unigrams,
+ float distortion,
+ int32 num_reserved_ids,
+ int32 num_shards, int32 shard)
+ : RangeSampler(range),
+ total_weight_(0.0),
+ num_shards_(num_shards),
+ shard_(shard) {
+ FillReservedIds(num_reserved_ids);
+ LoadFromUnigrams(unigrams, distortion);
+ // TODO(vanhoucke): make this non-crashing.
+ CHECK_EQ(range, weights_.size());
+ dist_sampler_.reset(new random::DistributionSampler(weights_));
+}
+
+float FixedUnigramSampler::Probability(int64 value) const {
+ return weights_.at(value) / total_weight_;
+}
+
+int64 FixedUnigramSampler::Sample(random::SimplePhilox* rnd) const {
+ return dist_sampler_->Sample(rnd);
+}
+
+void FixedUnigramSampler::FillReservedIds(int32 num_reserved_ids) {
+ for (int32 word_id = 0; word_id < num_reserved_ids; ++word_id) {
+ if (word_id % num_shards_ == shard_) weights_.push_back(0.0);
+ }
+}
+
+Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file,
+ float distortion) {
+ RandomAccessFile* file;
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file));
+ io::InputBuffer in(file, 262144 /*bytes*/);
+ string line;
+ int32 word_id = weights_.size();
+ while (in.ReadLine(&line).ok()) {
+    // The vocabulary file should be in CSV-like format, with the last
+    // field being the weight associated with the word.
+ std::vector<string> cols = str_util::Split(line, ',');
+ if (cols.size() == 0) continue;
+ // Skip entries that do not belong to this shard.
+ if (word_id % num_shards_ == shard_) {
+ float w = 0.0;
+ if (!strings::safe_strtof(cols.at(cols.size() - 1).c_str(), &w)) {
+ return errors::InvalidArgument("Wrong vocabulary format at line: ",
+ line);
+ }
+ w = pow(w, distortion);
+ total_weight_ += w;
+ weights_.push_back(w);
+ }
+ ++word_id;
+ }
+ return Status::OK();
+}
+
+void FixedUnigramSampler::LoadFromUnigrams(const std::vector<float>& unigrams,
+ float distortion) {
+ int32 word_id = weights_.size();
+ for (float w : unigrams) {
+ // Skip entries that do not belong to this shard.
+ if (word_id % num_shards_ == shard_) {
+ w = pow(w, distortion);
+ total_weight_ += w;
+ weights_.push_back(w);
+ }
+ ++word_id;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
new file mode 100644
index 0000000000..18364c2c03
--- /dev/null
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -0,0 +1,237 @@
+#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include "tensorflow/core/lib/random/weighted_picker.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class Env;
+
+// Abstract subclass for sampling from the set of non-negative integers
+// [0, range)
+class RangeSampler {
+ public:
+  explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); }
+ virtual ~RangeSampler();
+
+ // Sample a single value
+ virtual int64 Sample(random::SimplePhilox* rnd) const = 0;
+
+ // The probability that a single call to Sample() returns the given value.
+ // Assumes that value is in [0, range). No range checking is done.
+ virtual float Probability(int64 value) const = 0;
+
+ // Fill "batch" with samples from the distribution.
+ // If unique=true, then we re-pick each element until we get a
+ // value distinct from all previously picked values in the batch.
+ void SampleBatch(random::SimplePhilox* rnd, bool unique,
+ gtl::MutableArraySlice<int64> batch) const;
+
+ // Fill "batch" with samples from the distribution, and report
+ // "expected counts".
+ //
+ // The "expected count" of a value is an estimate of the expected
+ // number of occurrences of the value in the batch returned by a
+ // call to this function with the given parameters. If unique=true,
+ // the expected count is an inclusion probability. For details on
+ // this estimation, see the comment to "ExpectedCountHelper" in the
+ // .cc file.
+ //
+ // Expected counts for the elements of the returned "batch" are reported
+ // in the aligned array "batch_expected_count".
+ //
+  // The user can optionally provide "extras", containing values in the range.
+ // The expected counts for the extras are reported in the aligned array
+ // "extras_expected_count".
+ //
+ // "batch_expected_count" must have size equal to 0 or to the size of "batch".
+ // "extras" and "extras_expected_count" must have equal size.
+ void SampleBatchGetExpectedCount(
+ random::SimplePhilox* rnd, bool unique,
+ gtl::MutableArraySlice<int64> batch,
+ gtl::MutableArraySlice<float> batch_expected_count,
+ gtl::ArraySlice<int64> extras,
+ gtl::MutableArraySlice<float> extras_expected_count) const;
+
+ // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
+ // We repick to avoid all of the values in "avoided_values".
+ // "avoided_values" is only supported with unique=true. If
+ // unique=false, then avoided_values must be empty.
+ virtual void SampleBatchGetExpectedCountAvoid(
+ random::SimplePhilox* rnd, bool unique,
+ gtl::MutableArraySlice<int64> batch,
+ gtl::MutableArraySlice<float> batch_expected_count,
+ gtl::ArraySlice<int64> extras,
+ gtl::MutableArraySlice<float> extras_expected_count,
+ gtl::ArraySlice<int64> avoided_values) const;
+
+  // Does this sampler need to be updated with values (e.g. UnigramSampler)?
+ virtual bool NeedsUpdates() const { return false; }
+
+ // Updates the underlying distribution
+ virtual void Update(gtl::ArraySlice<int64> values) {
+ LOG(FATAL) << "Update not supported for this sampler type.";
+ }
+
+ int64 range() { return range_; }
+
+ protected:
+ const int64 range_;
+};
+
+// An AllSampler only samples batches of size equal to range.
+// It returns the entire range.
+// It cannot sample single values.
+class AllSampler : public RangeSampler {
+ public:
+ explicit AllSampler(int64 range);
+
+ ~AllSampler() override {}
+
+ int64 Sample(random::SimplePhilox* rnd) const override {
+ LOG(FATAL) << "Should not be called";
+ }
+
+ float Probability(int64 value) const override {
+ LOG(FATAL) << "Should not be called";
+ }
+
+ void SampleBatchGetExpectedCountAvoid(
+ random::SimplePhilox* rnd, bool unique,
+ gtl::MutableArraySlice<int64> batch,
+ gtl::MutableArraySlice<float> batch_expected_count,
+ gtl::ArraySlice<int64> extras,
+ gtl::MutableArraySlice<float> extras_expected_count,
+ gtl::ArraySlice<int64> avoided_values) const override;
+
+ private:
+ const float inv_range_;
+};
+
+class UniformSampler : public RangeSampler {
+ public:
+ explicit UniformSampler(int64 range);
+
+ ~UniformSampler() override {}
+
+ int64 Sample(random::SimplePhilox* rnd) const override;
+
+ float Probability(int64 value) const override;
+
+ private:
+ const float inv_range_;
+};
+
+class LogUniformSampler : public RangeSampler {
+ public:
+ explicit LogUniformSampler(int64 range);
+
+ ~LogUniformSampler() override {}
+
+ int64 Sample(random::SimplePhilox* rnd) const override;
+
+ float Probability(int64 value) const override;
+
+ private:
+ const double log_range_;
+};
+
+// Thread-unsafe unigram sampler
+class ThreadUnsafeUnigramSampler : public RangeSampler {
+ public:
+ explicit ThreadUnsafeUnigramSampler(int64 range);
+ ~ThreadUnsafeUnigramSampler() override {}
+
+ int64 Sample(random::SimplePhilox* rnd) const override;
+
+ float Probability(int64 value) const override;
+
+ bool NeedsUpdates() const override { return true; }
+ void Update(gtl::ArraySlice<int64> values) override;
+
+ private:
+ random::WeightedPicker picker_;
+};
+
+// Thread-safe unigram sampler
+class UnigramSampler : public RangeSampler {
+ public:
+ explicit UnigramSampler(int64 range);
+ ~UnigramSampler() override {}
+
+ int64 Sample(random::SimplePhilox* rnd) const override;
+
+ float Probability(int64 value) const override;
+
+  // Overriding at a high level results in far fewer lock acquisitions.
+ void SampleBatchGetExpectedCountAvoid(
+ random::SimplePhilox* rnd, bool unique,
+ gtl::MutableArraySlice<int64> batch,
+ gtl::MutableArraySlice<float> batch_expected_count,
+ gtl::ArraySlice<int64> extras,
+ gtl::MutableArraySlice<float> extras_expected_count,
+ gtl::ArraySlice<int64> avoided_values) const override;
+
+ bool NeedsUpdates() const override { return true; }
+ void Update(gtl::ArraySlice<int64> values) override;
+
+ private:
+ ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
+ mutable mutex mu_;
+};
+
+// A unigram sampler that uses a fixed unigram distribution read from a
+// file or passed in as an in-memory array instead of building up the
+// distribution from data on the fly. There is also an option to skew the
+// distribution by applying a distortion power to the weights.
+class FixedUnigramSampler : public RangeSampler {
+ public:
+ // The vocab_file is assumed to be a CSV, with the last entry of each row a
+ // value representing the counts or probabilities for the corresponding ID.
+ FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
+ float distortion, int32 num_reserved_ids,
+ int32 num_shards, int32 shard);
+
+ FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
+ float distortion, int32 num_reserved_ids,
+ int32 num_shards, int32 shard);
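+
+  // Example (taken from range_sampler_test.cc): an in-memory distribution
+  // over 9 IDs with distortion 0.8, no reserved IDs and a single shard:
+  //
+  //   std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+  //   FixedUnigramSampler sampler(9, weights, 0.8, 0, 1, 0);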
+
+ float Probability(int64 value) const override;
+
+ int64 Sample(random::SimplePhilox* rnd) const override;
+
+ private:
+ // Underlying distribution sampler.
+ std::unique_ptr<random::DistributionSampler> dist_sampler_;
+ // Weights for individual samples. The probability of a sample i is defined
+ // as weights_.at(i) / total_weight_.
+ std::vector<float> weights_;
+ // The total weights of all samples.
+ float total_weight_;
+ // Sharding information of the sampler. The whole vocabulary is sharded
+ // into num_shards_ smaller ranges and each sampler is responsible for one
+ // such smaller range, identified by the shard number.
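+  // For example, with num_shards_ = 3 and shard_ = 1, only the weights of
+  // vocabulary entries 1, 4, 7, ... (those with id % 3 == 1) are loaded
+  // into this sampler.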
+ int32 num_shards_;
+ int32 shard_;
+
+ // Fill the sampler with the appropriate number of reserved IDs.
+ void FillReservedIds(int32 num_reserved_ids);
+ // Load IDs to sample from a CSV file. It is assumed that the last item of
+ // each row contains a count or probability for the corresponding ID.
+ Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
+ // Load from an in-memory array.
+ void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc
new file mode 100644
index 0000000000..72c39009e4
--- /dev/null
+++ b/tensorflow/core/kernels/range_sampler_test.cc
@@ -0,0 +1,320 @@
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/kernels/range_sampler.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace {
+
+using gtl::ArraySlice;
+using gtl::MutableArraySlice;
+
+class RangeSamplerTest : public ::testing::Test {
+ protected:
+ void CheckProbabilitiesSumToOne() {
+ double sum = 0;
+ for (int i = 0; i < sampler_->range(); i++) {
+ sum += sampler_->Probability(i);
+ }
+ EXPECT_NEAR(sum, 1.0, 1e-4);
+ }
+ void CheckHistogram(int num_samples, float tolerance) {
+ const int range = sampler_->range();
+ std::vector<int> h(range);
+ std::vector<int64> a(num_samples);
+ // Using a fixed random seed to make the test deterministic.
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rnd(&philox);
+ sampler_->SampleBatch(&rnd, false, &a);
+ for (int i = 0; i < num_samples; i++) {
+ int64 val = a[i];
+ ASSERT_GE(val, 0);
+ ASSERT_LT(val, range);
+ h[val]++;
+ }
+ for (int val = 0; val < range; val++) {
+ EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val),
+ tolerance);
+ }
+ }
+ void Update1() {
+ // Add the value 3 ten times.
+ std::vector<int64> a(10);
+ for (int i = 0; i < 10; i++) {
+ a[i] = 3;
+ }
+ sampler_->Update(a);
+ }
+ void Update2() {
+ // Add the value n n times.
+ int64 a[10];
+ for (int i = 0; i < 10; i++) {
+ a[i] = i;
+ }
+ for (int64 i = 1; i < 10; i++) {
+ sampler_->Update(ArraySlice<int64>(a + i, 10 - i));
+ }
+ }
+ std::unique_ptr<RangeSampler> sampler_;
+};
+
+TEST_F(RangeSamplerTest, UniformProbabilities) {
+ sampler_.reset(new UniformSampler(10));
+ for (int i = 0; i < 10; i++) {
+ CHECK_EQ(sampler_->Probability(i), sampler_->Probability(0));
+ }
+}
+
+TEST_F(RangeSamplerTest, UniformChecksum) {
+ sampler_.reset(new UniformSampler(10));
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, UniformHistogram) {
+ sampler_.reset(new UniformSampler(10));
+ CheckHistogram(1000, 0.05);
+}
+
+TEST_F(RangeSamplerTest, LogUniformProbabilities) {
+ int range = 1000000;
+ sampler_.reset(new LogUniformSampler(range));
+ for (int i = 100; i < range; i *= 2) {
+ float ratio = sampler_->Probability(i) / sampler_->Probability(i / 2);
+ EXPECT_NEAR(ratio, 0.5, 0.1);
+ }
+}
+
+TEST_F(RangeSamplerTest, LogUniformChecksum) {
+ sampler_.reset(new LogUniformSampler(10));
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, LogUniformHistogram) {
+ sampler_.reset(new LogUniformSampler(10));
+ CheckHistogram(1000, 0.05);
+}
+
+TEST_F(RangeSamplerTest, UnigramProbabilities1) {
+ sampler_.reset(new UnigramSampler(10));
+ Update1();
+ EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4);
+ for (int i = 0; i < 10; i++) {
+ if (i != 3) {
+ ASSERT_NEAR(sampler_->Probability(i), 0.05, 1e-4);
+ }
+ }
+}
+TEST_F(RangeSamplerTest, UnigramProbabilities2) {
+ sampler_.reset(new UnigramSampler(10));
+ Update2();
+ for (int i = 0; i < 10; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), (i + 1) / 55.0, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, UnigramChecksum) {
+ sampler_.reset(new UnigramSampler(10));
+ Update1();
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, UnigramHistogram) {
+ sampler_.reset(new UnigramSampler(10));
+ Update1();
+ CheckHistogram(1000, 0.05);
+}
+
+static const char kVocabContent[] =
+ "w1,1\n"
+ "w2,2\n"
+ "w3,4\n"
+ "w4,8\n"
+ "w5,16\n"
+ "w6,32\n"
+ "w7,64\n"
+ "w8,128\n"
+ "w9,256";
+TEST_F(RangeSamplerTest, FixedUnigramProbabilities) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 0; i < 9; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramChecksum) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
+ CheckProbabilitiesSumToOne();
+}
+
+TEST_F(RangeSamplerTest, FixedUnigramHistogram) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
+ CheckHistogram(1000, 0.05);
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 1; i < 10; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) {
+ Env* env = Env::Default();
+ string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
+ TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
+ sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 2; i < 11; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 0; i < 9; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
+ CheckProbabilitiesSumToOne();
+}
+TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
+ CheckHistogram(1000, 0.05);
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 1; i < 10; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
+ }
+}
+TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) {
+ std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
+ sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0));
+ ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
+ ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
+ // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
+ for (int i = 2; i < 11; i++) {
+ ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
+ }
+}
+
+// AllSampler does not support calling Sample or Probability directly,
+// so we test SampleBatchGetExpectedCount instead.
+TEST_F(RangeSamplerTest, All) {
+ int batch_size = 10;
+ sampler_.reset(new AllSampler(10));
+ std::vector<int64> batch(batch_size);
+ std::vector<float> batch_expected(batch_size);
+ std::vector<int64> extras(2);
+ std::vector<float> extras_expected(2);
+ extras[0] = 0;
+ extras[1] = batch_size - 1;
+ sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed
+ false, &batch, &batch_expected, extras,
+ &extras_expected);
+ for (int i = 0; i < batch_size; i++) {
+ EXPECT_EQ(i, batch[i]);
+ EXPECT_EQ(1, batch_expected[i]);
+ }
+ EXPECT_EQ(1, extras_expected[0]);
+ EXPECT_EQ(1, extras_expected[1]);
+}
+
+TEST_F(RangeSamplerTest, Unique) {
+ // We sample num_batches batches, each without replacement.
+ //
+ // We check that the returned expected counts roughly agree with each other
+ // and with the average observed frequencies over the set of batches.
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rnd(&philox);
+ const int range = 100;
+ const int batch_size = 50;
+ const int num_batches = 100;
+ sampler_.reset(new LogUniformSampler(range));
+ std::vector<int> histogram(range);
+ std::vector<int64> batch(batch_size);
+ std::vector<int64> all_values(range);
+ for (int i = 0; i < range; i++) {
+ all_values[i] = i;
+ }
+ std::vector<float> expected(range);
+
+ // Sample one batch and get the expected counts of all values
+ sampler_->SampleBatchGetExpectedCount(
+ &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
+ // Check that all elements are unique
+ std::set<int64> s(batch.begin(), batch.end());
+ CHECK_EQ(batch_size, s.size());
+
+ for (int trial = 0; trial < num_batches; trial++) {
+ std::vector<float> trial_expected(range);
+ sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
+ MutableArraySlice<float>(),
+ all_values, &trial_expected);
+ for (int i = 0; i < range; i++) {
+ EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
+ }
+ for (int i = 0; i < batch_size; i++) {
+ histogram[batch[i]]++;
+ }
+ }
+ for (int i = 0; i < range; i++) {
+ // Check that the computed expected count agrees with the average observed
+ // count.
+ const float average_count = static_cast<float>(histogram[i]) / num_batches;
+ EXPECT_NEAR(expected[i], average_count, 0.2);
+ }
+}
+
+TEST_F(RangeSamplerTest, Avoid) {
+ random::PhiloxRandom philox(123, 17);
+ random::SimplePhilox rnd(&philox);
+ sampler_.reset(new LogUniformSampler(100));
+ std::vector<int64> avoided(2);
+ avoided[0] = 17;
+ avoided[1] = 23;
+ std::vector<int64> batch(98);
+
+ // We expect to pick all elements of [0, 100) except the avoided two.
+ sampler_->SampleBatchGetExpectedCountAvoid(
+ &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
+ MutableArraySlice<float>(), avoided);
+
+ int sum = 0;
+ for (auto val : batch) {
+ sum += val;
+ }
+ const int expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1];
+ EXPECT_EQ(expected_sum, sum);
+}
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reader_base.cc b/tensorflow/core/kernels/reader_base.cc
new file mode 100644
index 0000000000..06211efb38
--- /dev/null
+++ b/tensorflow/core/kernels/reader_base.cc
@@ -0,0 +1,156 @@
+#include "tensorflow/core/kernels/reader_base.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+// ReaderBase ------------------------------------------------------
+
+ReaderBase::ReaderBase(const string& name) : name_(name) {}
+
+int64 ReaderBase::NumRecordsProduced() {
+ mutex_lock lock(mu_);
+ return num_records_produced_;
+}
+
+int64 ReaderBase::NumWorkUnitsCompleted() {
+ mutex_lock lock(mu_);
+ return work_finished_;
+}
+
+Status ReaderBase::Reset() {
+ mutex_lock lock(mu_);
+ return ResetLocked();
+}
+
+Status ReaderBase::ResetLocked() {
+ work_started_ = 0;
+ work_finished_ = 0;
+ num_records_produced_ = 0;
+ work_.clear();
+ return Status::OK();
+}
+
+Status ReaderBase::SerializeState(string* state) {
+ mutex_lock lock(mu_);
+ return SerializeStateLocked(state);
+}
+
+Status ReaderBase::SerializeStateLocked(string* state) {
+ return errors::Unimplemented("Reader SerializeState");
+}
+
+Status ReaderBase::RestoreState(const string& state) {
+ mutex_lock lock(mu_);
+ Status status = RestoreStateLocked(state);
+ if (!status.ok()) {
+ ResetLocked();
+ }
+ return status;
+}
+
+Status ReaderBase::RestoreStateLocked(const string& state) {
+ return errors::Unimplemented("Reader RestoreState");
+}
+
+void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
+ OpKernelContext* context) {
+ mutex_lock lock(mu_);
+ while (true) {
+ if (!work_in_progress()) {
+ GetNextWorkLocked(queue, context);
+ if (!context->status().ok()) return;
+ }
+
+ bool produced = false;
+ bool at_end = false;
+ Status status = ReadLocked(key, value, &produced, &at_end);
+
+ if (!at_end && status.ok() && !produced) {
+ status = errors::Internal(
+ "ReadLocked() for ", name(),
+ " must set *at_end=true, *produced=true, or return an error.");
+ }
+ if (!status.ok() && produced) {
+ status = errors::Internal("ReadLocked() for ", name(),
+ " set *produced=true *and* returned an error: ",
+ status.ToString());
+ }
+ if (status.ok() && at_end) {
+ status = OnWorkFinishedLocked();
+ work_finished_ = work_started_;
+ }
+ if (!status.ok()) {
+ context->SetStatus(status);
+ return;
+ }
+ if (produced) {
+ ++num_records_produced_;
+ return;
+ }
+ }
+}
+
+void ReaderBase::GetNextWorkLocked(QueueInterface* queue,
+ OpKernelContext* context) {
+ Notification n;
+ queue->TryDequeue(
+ context, [this, context, &n](const QueueInterface::Tuple& tuple) {
+ if (context->status().ok()) {
+ if (tuple.size() != 1) {
+ context->SetStatus(
+ errors::InvalidArgument("Expected single component queue"));
+ } else if (tuple[0].dtype() != DT_STRING) {
+ context->SetStatus(errors::InvalidArgument(
+ "Expected queue with single string component"));
+ } else if (tuple[0].NumElements() != 1) {
+ context->SetStatus(errors::InvalidArgument(
+ "Expected to dequeue a one-element string tensor"));
+ } else {
+ work_ = tuple[0].flat<string>()(0);
+ ++work_started_;
+ Status status = OnWorkStartedLocked();
+ if (!status.ok()) {
+ context->SetStatus(status);
+ --work_started_;
+ }
+ }
+ }
+ n.Notify();
+ });
+ n.WaitForNotification();
+}
+
+void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
+ state->Clear();
+ state->set_work_started(work_started_);
+ state->set_work_finished(work_finished_);
+ state->set_num_records_produced(num_records_produced_);
+ state->set_current_work(work_);
+}
+
+Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
+ work_started_ = state.work_started();
+ work_finished_ = state.work_finished();
+ num_records_produced_ = state.num_records_produced();
+ work_ = state.current_work();
+ if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
+ return errors::InvalidArgument(
+ "Unexpected negative value when restoring in ", name(), ": ",
+ state.ShortDebugString());
+ }
+ if (work_started_ > work_finished_) {
+ return errors::InvalidArgument(
+ "Inconsistent work started vs. finished when restoring in ", name(),
+ ": ", state.ShortDebugString());
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reader_base.h b/tensorflow/core/kernels/reader_base.h
new file mode 100644
index 0000000000..d344300388
--- /dev/null
+++ b/tensorflow/core/kernels/reader_base.h
@@ -0,0 +1,107 @@
+#ifndef TENSORFLOW_KERNELS_READER_BASE_H_
+#define TENSORFLOW_KERNELS_READER_BASE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/reader_interface.h"
+#include "tensorflow/core/kernels/reader_base.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+// Default implementation of ReaderInterface.
+class ReaderBase : public ReaderInterface {
+ public:
+ // name: For use in error messages, should mention both the name of
+ // the op and the node.
+ explicit ReaderBase(const string& name);
+
+ // Note that methods with names ending in "Locked" are called while
+ // the ReaderBase's mutex is held.
+
+ // Implement this function in descendants -----------------------------------
+
+ // Produce the next key/value pair from the current work item.
+ // This is called "Locked" since it is executed under a mutex
+ // that serializes all Reader calls.
+ // Usage:
+ // a) If a record was successfully produced, set *produced = true,
+ // and fill in *key and *value.
+ // b) If no more records will be produced for this work item, set
+ // *at_end = true.
+ // c) If a record was produced, but no more will be produced, you
+ // may either do both (a) and (b), or do (a) in this call and do (b) in
+ // the next call to ReadLocked().
+ // d) If there was an error producing (e.g. an error reading the file,
+ // data corruption), return a non-OK() status. ReadLocked may be
+ // called again if the user reruns this part of the graph.
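+  //
+  // Minimal sketch of an implementation that yields one record per line of
+  // the current work item (lines_ and next_line_ are hypothetical members,
+  // shown only to illustrate the contract above):
+  //
+  //   Status ReadLocked(string* key, string* value, bool* produced,
+  //                     bool* at_end) override {
+  //     if (next_line_ >= lines_.size()) {
+  //       *at_end = true;                         // case (b)
+  //       return Status::OK();
+  //     }
+  //     *key = strings::StrCat(current_work(), ":", next_line_);
+  //     *value = lines_[next_line_++];
+  //     *produced = true;                         // case (a)
+  //     return Status::OK();
+  //   }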
+ virtual Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) = 0;
+
+ // Descendants may optionally implement these -------------------------------
+
+ // Called when work starts / finishes.
+ virtual Status OnWorkStartedLocked() { return Status::OK(); }
+ virtual Status OnWorkFinishedLocked() { return Status::OK(); }
+
+ // Called to reset the Reader to a newly constructed state.
+ virtual Status ResetLocked();
+
+ // Default implementation generates an Unimplemented error.
+ // See the protected helper methods below.
+ virtual Status SerializeStateLocked(string* state);
+ virtual Status RestoreStateLocked(const string& state);
+
+ // Accessors ----------------------------------------------------------------
+
+ // Always true during a call to ReadLocked().
+ bool work_in_progress() const { return work_finished_ < work_started_; }
+
+ // Returns the name of the current work item (valid if
+ // work_in_progress() returns true). May change between calls to
+ // ReadLocked().
+ const string& current_work() const { return work_; }
+
+ // What was passed to the constructor.
+ const string& name() const { return name_; }
+
+ protected:
+ // For descendants wishing to implement serialize & restore state.
+
+ // Writes ReaderBase state to *state.
+ void SaveBaseState(ReaderBaseState* state) const;
+
+ // Restores ReaderBase state from state. Assumes state was filled
+ // using SaveBaseState() above.
+ Status RestoreBaseState(const ReaderBaseState& state);
+
+ private:
+ // Implementations of ReaderInterface methods. These ensure thread-safety
+ // and call the methods above to do the work.
+ void Read(QueueInterface* queue, string* key, string* value,
+ OpKernelContext* context) override;
+ Status Reset() override;
+ int64 NumRecordsProduced() override;
+ int64 NumWorkUnitsCompleted() override;
+ Status SerializeState(string* state) override;
+ Status RestoreState(const string& state) override;
+
+ // For implementing Read(). Dequeues the next work item from
+ // *queue, and if successful updates work_, work_started_
+ // (establishing work_in_progress() == true) and calls
+ // OnWorkStartedLocked(). May block.
+ void GetNextWorkLocked(QueueInterface* queue, OpKernelContext* context);
+
+ mutable mutex mu_;
+ const string name_;
+ int64 work_started_ = 0;
+ int64 work_finished_ = 0;
+ int64 num_records_produced_ = 0;
+ string work_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_READER_BASE_H_
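To make the ReadLocked()/ResetLocked() contract documented above concrete, here is an illustrative sketch of a hypothetical subclass; the class and its members are invented for the example and are not part of this diff, and only strings::StrCat from the existing string utilities is assumed.

    // Hypothetical example (not from the TensorFlow sources): a ReaderBase
    // subclass that emits the records "0" .. "n-1" for every work item,
    // restarting the counter each time a new work item is dequeued.
    #include "tensorflow/core/kernels/reader_base.h"
    #include "tensorflow/core/lib/strings/strcat.h"

    namespace tensorflow {

    class CountingReader : public ReaderBase {
     public:
      CountingReader(const string& node_name, int64 n)
          : ReaderBase(strings::StrCat("CountingReader '", node_name, "'")),
            n_(n) {}

      Status OnWorkStartedLocked() override {
        next_ = 0;  // A new work item starts counting from zero.
        return Status::OK();
      }

      Status ReadLocked(string* key, string* value, bool* produced,
                        bool* at_end) override {
        if (next_ >= n_) {
          *at_end = true;  // Case (b): nothing more for this work item.
          return Status::OK();
        }
        // Case (a): produce one record; current_work() names the work item.
        *key = strings::StrCat(current_work(), ":", next_);
        *value = strings::StrCat(next_);
        *produced = true;
        ++next_;
        return Status::OK();
      }

      Status ResetLocked() override {
        next_ = 0;
        return ReaderBase::ResetLocked();
      }

     private:
      const int64 n_;
      int64 next_ = 0;
    };

    }  // namespace tensorflow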
diff --git a/tensorflow/core/kernels/reader_base.proto b/tensorflow/core/kernels/reader_base.proto
new file mode 100644
index 0000000000..4335cb2152
--- /dev/null
+++ b/tensorflow/core/kernels/reader_base.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// For serializing and restoring the state of ReaderBase, see
+// reader_base.h for details.
+message ReaderBaseState {
+ int64 work_started = 1;
+ int64 work_finished = 2;
+ int64 num_records_produced = 3;
+ bytes current_work = 4;
+};
diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc
new file mode 100644
index 0000000000..38c1013604
--- /dev/null
+++ b/tensorflow/core/kernels/reader_ops.cc
@@ -0,0 +1,132 @@
+// See docs in ../ops/io_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/reader_interface.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class ReaderVerbOpKernel : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+ ReaderInterface* reader;
+ OP_REQUIRES_OK(context,
+ GetResourceFromContext(context, "reader_handle", &reader));
+ ComputeWithReader(context, reader);
+ reader->Unref();
+ }
+
+ protected:
+ virtual void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) = 0;
+};
+
+class ReaderReadOp : public ReaderVerbOpKernel {
+ public:
+ using ReaderVerbOpKernel::ReaderVerbOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ QueueInterface* queue;
+ OP_REQUIRES_OK(context,
+ GetResourceFromContext(context, "queue_handle", &queue));
+ core::ScopedUnref unref_me(queue);
+ Tensor* key = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("key", TensorShape({}), &key));
+ Tensor* value = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("value", TensorShape({}), &value));
+
+ auto key_scalar = key->scalar<string>();
+ auto value_scalar = value->scalar<string>();
+ reader->Read(queue, &key_scalar(), &value_scalar(), context);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
+
+class ReaderNumRecordsProducedOp : public ReaderVerbOpKernel {
+ public:
+ using ReaderVerbOpKernel::ReaderVerbOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("records_produced",
+ TensorShape({}), &output));
+ output->scalar<int64>()() = reader->NumRecordsProduced();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU),
+ ReaderNumRecordsProducedOp);
+
+class ReaderNumWorkUnitsCompletedOp : public ReaderVerbOpKernel {
+ public:
+ using ReaderVerbOpKernel::ReaderVerbOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("units_completed",
+ TensorShape({}), &output));
+ output->scalar<int64>()() = reader->NumWorkUnitsCompleted();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU),
+ ReaderNumWorkUnitsCompletedOp);
+
+class ReaderSerializeStateOp : public ReaderVerbOpKernel {
+ public:
+ using ReaderVerbOpKernel::ReaderVerbOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("state", TensorShape({}), &output));
+ OP_REQUIRES_OK(context,
+ reader->SerializeState(&output->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU),
+ ReaderSerializeStateOp);
+
+class ReaderRestoreStateOp : public ReaderVerbOpKernel {
+ public:
+ using ReaderVerbOpKernel::ReaderVerbOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ const Tensor* tensor;
+ OP_REQUIRES_OK(context, context->input("state", &tensor));
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(tensor->shape()),
+ errors::InvalidArgument("Reader state must be scalar, but had shape: ",
+ tensor->shape().DebugString()));
+ OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU),
+ ReaderRestoreStateOp);
+
+class ReaderResetOp : public ReaderVerbOpKernel {
+ public:
+ using ReaderVerbOpKernel::ReaderVerbOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ OP_REQUIRES_OK(context, reader->Reset());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h
new file mode 100644
index 0000000000..b412617a65
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops.h
@@ -0,0 +1,66 @@
+#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_H_
+#define TENSORFLOW_KERNELS_REDUCTION_OPS_H_
+
+// Functor definitions for Reduction ops, must be compilable by nvcc.
+
+#include <iostream>
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// When Eigen has better implementations of AllReducer and AnyReducer,
+// replace the reducers here.
+
+// Reduction using logical_and.
+struct AllReducer {
+ // TODO(zhifengc): Implement PacketAccess when performance matters.
+ static const bool PacketAccess = false;
+ static const bool IsStateful = false;
+
+ EIGEN_DEVICE_FUNC void reduce(const bool t, bool* accum) const {
+ *accum &= t;
+ }
+
+ EIGEN_DEVICE_FUNC bool initialize() const { return true; }
+
+ EIGEN_DEVICE_FUNC bool finalize(const bool accum) const { return accum; }
+};
+
+// Reduction using logical_or.
+struct AnyReducer {
+ // TODO(zhifengc): Implement PacketAccess when performance matters.
+ static const bool PacketAccess = false;
+ static const bool IsStateful = false;
+
+ EIGEN_DEVICE_FUNC void reduce(const bool t, bool* accum) const {
+ *accum |= t;
+ }
+
+ EIGEN_DEVICE_FUNC bool initialize() const { return false; }
+
+ EIGEN_DEVICE_FUNC bool finalize(const bool accum) const { return accum; }
+};
+
+template <typename Device, typename OUT_T, typename IN_T,
+ typename ReductionAxes, typename Reducer>
+void ReduceEigenImpl(const Device& d, OUT_T out, IN_T in,
+ const ReductionAxes& reduction_axes,
+ const Reducer& reducer) {
+ out.device(d) = in.reduce(reduction_axes, reducer);
+}
+
+template <typename Device>
+struct ReduceFunctor {
+ template <typename OUT_T, typename IN_T, typename ReductionAxes,
+ typename Reducer>
+ static void Reduce(const Device& d, OUT_T out, IN_T in,
+ const ReductionAxes& reduction_axes,
+ const Reducer& reducer);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_H_
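The reducer structs above follow the Eigen reducer protocol: initialize() yields the identity element, reduce() folds one input into the accumulator, and finalize() produces the result. The following self-contained sketch shows that protocol being driven by hand, independent of Eigen; the ReduceAll driver is invented for the example.

    // Illustration only (not part of the diff): how an Eigen-style reducer
    // such as AllReducer above is driven, i.e. initialize(), then reduce()
    // once per element, then finalize().
    #include <iostream>
    #include <vector>

    struct AllReducer {
      void reduce(const bool t, bool* accum) const { *accum &= t; }
      bool initialize() const { return true; }
      bool finalize(const bool accum) const { return accum; }
    };

    template <typename Reducer>
    bool ReduceAll(const std::vector<bool>& values, const Reducer& reducer) {
      bool accum = reducer.initialize();
      for (bool v : values) reducer.reduce(v, &accum);
      return reducer.finalize(accum);
    }

    int main() {
      AllReducer all;
      std::cout << ReduceAll({true, true, false}, all) << "\n";  // prints 0
      std::cout << ReduceAll({true, true, true}, all) << "\n";   // prints 1
    }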
diff --git a/tensorflow/core/kernels/reduction_ops_all.cc b/tensorflow/core/kernels/reduction_ops_all.cc
new file mode 100644
index 0000000000..11d399e70a
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_all.cc
@@ -0,0 +1,17 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("All")
+ .Device(DEVICE_CPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, bool, functor::AllReducer>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("All")
+ .Device(DEVICE_GPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<GPUDevice, bool, functor::AllReducer>);
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_any.cc b/tensorflow/core/kernels/reduction_ops_any.cc
new file mode 100644
index 0000000000..a89ef22b08
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_any.cc
@@ -0,0 +1,17 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("Any")
+ .Device(DEVICE_CPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, bool, functor::AnyReducer>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("Any")
+ .Device(DEVICE_GPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<GPUDevice, bool, functor::AnyReducer>);
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
new file mode 100644
index 0000000000..2bde3a1a54
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -0,0 +1,302 @@
+// This is an internal header file intended to only be included as the
+// front-matter in the implementation files of various reduction ops. It
+// is a header file because we split the various reduction ops into their
+// own compilation units to get more parallelism in compilation.
+
+#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
+#define TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/reduction_ops.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device>
+struct Constants {
+ // Derive Index type. int (32-bit) or long (64-bit) depending on the
+ // compile-time configuration. "float" here is not relevant.
+ // TODO(zhifengc): Moves the definition to TTypes.
+ typedef TTypes<float>::Tensor::Index Index;
+ Eigen::array<Index, 1> kZero;
+ Eigen::array<Index, 1> kOne;
+ Eigen::array<Index, 2> kZeroTwo;
+
+ Constants() {
+ kZero[0] = 0;
+ kOne[0] = 1;
+ kZeroTwo[0] = 0;
+ kZeroTwo[1] = 2;
+ }
+};
+
+#if defined(EIGEN_HAS_INDEX_LIST)
+template <>
+struct Constants<CPUDevice> {
+ const Eigen::IndexList<Eigen::type2index<0>> kZero;
+ const Eigen::IndexList<Eigen::type2index<1>> kOne;
+ const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
+};
+#endif
+
+namespace {
+
+class ReductionHelper {
+ public:
+ ReductionHelper() : reduce_first_axis_(false) {}
+
+ Status Simplify(const Tensor& data, const Tensor& axis,
+ const bool keep_dims) {
+ // bitmap[i] indicates whether to reduce data along i-th axis.
+ std::vector<bool> bitmap(data.dims(), false);
+ auto axis_vec = axis.flat<int32>();
+ for (int64 i = 0; i < axis.NumElements(); ++i) {
+ const int32 index = axis_vec(i);
+ if (index < 0 || index >= data.dims()) {
+        return errors::OutOfRange("Invalid reduction dimension (", index,
+                                  ") for input with ", data.dims(),
+                                  " dimension(s)");
+ }
+ bitmap[index] = true;
+ }
+
+ // Output tensor's dim sizes.
+ out_shape_.clear();
+ for (int i = 0; i < data.dims(); ++i) {
+ if (!bitmap[i]) {
+ // If we are not reducing along dimension i.
+ out_shape_.push_back(data.dim_size(i));
+ } else if (keep_dims) {
+ // We are reducing along dimension i, but we want to keep the
+ // same number of dimensions, so we set the dimension of i to
+ // '1'.
+ out_shape_.push_back(1);
+ }
+ }
+
+ // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
+ // the input data before doing the reduction on the resulting
+ // tensor. The shape of the reduction is a reshape of the final
+ // output.
+
+ // We'll skip the leading 1s.
+ int dim_index = 0;
+ for (; dim_index < data.dims(); ++dim_index) {
+ if (data.dim_size(dim_index) != 1) break;
+ }
+ if (dim_index >= data.dims()) {
+ // Special case. The input is essentially a scalar.
+ reduce_first_axis_ = true;
+ } else {
+      // Starting from the (dim_index)-th dimension, dimensions
+      // alternate between runs that need to be reduced and runs that
+      // don't.
+ //
+ // NOTE: If a dimension has size 1, we group it as the current
+ // run so that we can minimize the number of runs.
+ //
+ // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
+ // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
+ // and reduce by axes = [1] (i.e., the output is shape [6]).
+ reduce_first_axis_ = bitmap[dim_index];
+ data_reshape_.push_back(data.dim_size(dim_index));
+ ++dim_index;
+ for (; dim_index < data.dims(); ++dim_index) {
+ const auto size = data.dim_size(dim_index);
+ if (size == 1) {
+ bitmap[dim_index] = bitmap[dim_index - 1];
+ }
+ if (bitmap[dim_index - 1] != bitmap[dim_index]) {
+ // Starts a new run of reduce or !reduce.
+ data_reshape_.push_back(size);
+ } else {
+ // Continue a run of reduce or !reduce.
+ data_reshape_.back() *= size;
+ }
+ }
+      // If reduce_first_axis_ is true (the input's dimensions 0, 2, 4, etc.
+      // are reduced), then data_reshape_[1, 3, 5, ...] is out_reshape_;
+      // otherwise, data_reshape_[0, 2, 4, ...] is.
+ for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
+ i += 2) {
+ out_reshape_.push_back(data_reshape_[i]);
+ }
+ }
+
+ VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
+ VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
+ VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
+ return Status::OK();
+ }
+
+ // We need to do roughly:
+ // tmp_out = allocate(out_reshape())
+ // tmp_out.reshape(out_reshape) = data.reshape(data_reshape).reduce(axes)
+ // out = tmp_out.reshape(out_shape)
+
+ // The reduction result must be allocated with this shape.
+ TensorShape out_reshape() const {
+ TensorShape shape;
+ for (auto size : out_reshape_) shape.AddDim(size);
+ return shape;
+ }
+
+ // The final output shape must be allocated with this shape.
+ TensorShape out_shape() const {
+ TensorShape shape;
+ for (auto size : out_shape_) shape.AddDim(size);
+ return shape;
+ }
+
+ // The reduction is on a reshaped tensor of this rank.
+ int ndims() const { return data_reshape_.size(); }
+
+  // True if we need to reduce the 0-th dimension.
+ bool reduce_first_axis() const { return reduce_first_axis_; }
+
+ // The output is reshaped.
+ template <typename T, int N>
+ typename TTypes<T, N>::Tensor out(Tensor* out) {
+ return out->shaped<T, N>(out_reshape_);
+ }
+
+ // The input is reshaped.
+ template <typename T, int N>
+ typename TTypes<T, N>::ConstTensor in(const Tensor& data) {
+ return data.shaped<T, N>(data_reshape_);
+ }
+
+ private:
+  bool reduce_first_axis_;  // True if we need to reduce the 0-th dimension.
+ std::vector<int64> data_reshape_; // Reshape the data before reduction.
+ std::vector<int64> out_shape_; // The final output shape.
+ std::vector<int64> out_reshape_; // Reshape the output for reduction.
+};
+
+} // end namespace
+
+// For operations where the output is a reduction function along some
+// dimensions of the input.
+template <typename Device, class T, typename Reducer>
+class ReductionOp : public OpKernel {
+ public:
+ explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& data = ctx->input(0);
+ const Tensor& axes = ctx->input(1);
+ VLOG(1) << "data shape: " << data.shape().ShortDebugString();
+ VLOG(1) << "axes : " << axes.SummarizeValue(10);
+
+ ReductionHelper helper;
+ OP_REQUIRES_OK(ctx, helper.Simplify(data, axes, keep_dims_));
+ CHECK_GE(helper.ndims(), 0);
+
+ // The real output shape will be assigned below.
+ TensorShape empty_shape;
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, empty_shape, &out));
+
+ if (helper.ndims() == 0 ||
+ (helper.ndims() == 1 && !helper.reduce_first_axis())) {
+      // Special case: reduces nothing. It is unclear why this is
+      // necessary, but tests fail without it. TODO: look into why this
+      // case occurs.
+ if (!out->CopyFrom(data, helper.out_shape())) {
+ ctx->SetStatus(errors::Internal("Error during reduction copy."));
+ }
+ return;
+ }
+
+ // A temporary tensor whose size matches the size of the reduced
+ // output.
+ Tensor tmp_out;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(out->dtype(), helper.out_reshape(), &tmp_out));
+
+ typedef functor::ReduceFunctor<Device> Functor;
+ Constants<Device> constants;
+ const Device& d = ctx->eigen_device<Device>();
+ Reducer reducer;
+
+ if ((helper.ndims() == 1) && helper.reduce_first_axis()) {
+ // Reduce to a scalar.
+ Functor::Reduce(d, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data),
+ constants.kZero, reducer);
+ } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) {
+ // Can be viewed as a reduction of a matrix along 1st dimension.
+ Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
+ constants.kZero, reducer);
+ } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) {
+ // Can be viewed as a reduction of a matrix along 2nd dimension.
+ Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 2>(data),
+ constants.kOne, reducer);
+ } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) {
+ // Can be viewed as a reduction of a 3D tensor along 1st and 3rd
+ // dimensions.
+ Functor::Reduce(d, helper.out<T, 1>(&tmp_out), helper.in<T, 3>(data),
+ constants.kZeroTwo, reducer);
+ } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) {
+ // Can be viewed as a reduction of a 3D tensor along 2nd dimension.
+ Functor::Reduce(d, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data),
+ constants.kOne, reducer);
+ } else {
+ // TODO(zhifengc): We can implement reduction for arbitrary rank
+ // tensor and arbitrary reduction axes by iterating the reduction
+ // multiple times. This may also be accomplished in the graph
+ // construction.
+ ctx->SetStatus(
+ errors::Unimplemented("Reducing ", data.shape().ShortDebugString(),
+ " axes [", axes.SummarizeValue(10), "] to ",
+ tmp_out.shape().ShortDebugString()));
+ return;
+ }
+
+ // Set the real output using the contents of the reduction but the
+ // real expected output shape. The number of elements should
+ // match between the two shapes.
+ if (!out->CopyFrom(tmp_out, helper.out_shape())) {
+ ctx->SetStatus(errors::Internal("Error during reduction copy."));
+ }
+ }
+
+ private:
+ // True if the number of dimensions should be maintained.
+ bool keep_dims_;
+};
+
+namespace functor {
+
+template <>
+struct ReduceFunctor<CPUDevice> {
+ template <typename OUT_T, typename IN_T, typename ReductionAxes,
+ typename Reducer>
+ static void Reduce(const CPUDevice& d, OUT_T out, IN_T in,
+ const ReductionAxes& reduction_axes,
+ const Reducer& reducer) {
+ ReduceEigenImpl(d, out, in, reduction_axes, reducer);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
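The run-collapsing step in ReductionHelper::Simplify is the subtle part: adjacent dimensions with the same reduce/keep decision are merged into single runs, and size-1 dimensions are absorbed into the current run. The sketch below isolates just that step with invented names and plain types, and reproduces the [2, 1, 3, 1, 5] example from the comment above.

    // Standalone sketch (not part of the diff) of the run-collapsing idea in
    // ReductionHelper::Simplify.
    #include <cstdio>
    #include <vector>

    // Returns the collapsed shape; *reduce_first is true if run 0 is reduced.
    std::vector<long long> Collapse(const std::vector<long long>& dims,
                                    const std::vector<bool>& reduce,  // per dim
                                    bool* reduce_first) {
      std::vector<bool> bitmap = reduce;
      std::vector<long long> runs;
      size_t d = 0;
      while (d < dims.size() && dims[d] == 1) ++d;  // skip leading 1s
      if (d == dims.size()) {  // essentially a scalar
        *reduce_first = true;
        return runs;
      }
      *reduce_first = bitmap[d];
      runs.push_back(dims[d]);
      for (++d; d < dims.size(); ++d) {
        if (dims[d] == 1) bitmap[d] = bitmap[d - 1];  // absorb size-1 dims
        if (bitmap[d] != bitmap[d - 1]) {
          runs.push_back(dims[d]);  // start a new run of reduce or !reduce
        } else {
          runs.back() *= dims[d];   // continue the current run
        }
      }
      return runs;
    }

    int main() {
      bool reduce_first = false;
      // Shape [2, 1, 3, 1, 5] reduced along axes {1, 4}.
      auto runs = Collapse({2, 1, 3, 1, 5}, {false, true, false, false, true},
                           &reduce_first);
      for (long long r : runs) std::printf("%lld ", r);  // prints "6 5"
      std::printf("reduce_first=%d\n", reduce_first);    // prints 0
    }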
diff --git a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
new file mode 100644
index 0000000000..8e29d2d06c
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc
@@ -0,0 +1,65 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/kernels/reduction_ops.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Derive Index type. int (32-bit) or long (64-bit) depending on the
+// compile-time configuration. "float" here is not relevant.
+// TODO(zhifengc): Move the definition to TTypes.
+typedef TTypes<float>::Tensor::Index Index;
+
+template <>
+struct ReduceFunctor<GPUDevice> {
+ template <typename OUT_T, typename IN_T, typename ReductionAxes,
+ typename Reducer>
+ static void Reduce(const GPUDevice& d, OUT_T out, IN_T in,
+ const ReductionAxes& reduction_axes,
+ const Reducer& reducer) {
+ ReduceEigenImpl(d, To32Bit(out), To32Bit(in), reduction_axes, reducer);
+ }
+};
+
+// T: the data type
+// REDUCER: the reducer functor
+// NUM_AXES: the number of axes to reduce
+// IN_DIMS: the number of dimensions of the input tensor
+#define DEFINE(T, REDUCER, IN_DIMS, NUM_AXES) \
+ template void ReduceFunctor<GPUDevice>::Reduce( \
+ const GPUDevice& d, TTypes<T, IN_DIMS - NUM_AXES>::Tensor out, \
+ TTypes<T, IN_DIMS>::ConstTensor in, \
+ const Eigen::array<Index, NUM_AXES>& reduction_axes, \
+ const REDUCER& reducer);
+
+#define DEFINE_FOR_TYPE_AND_R(T, R) \
+ DEFINE(T, R, 1, 1); \
+ DEFINE(T, R, 2, 1); \
+ DEFINE(T, R, 3, 1); \
+ DEFINE(T, R, 3, 2);
+
+#define DEFINE_FOR_ALL_REDUCERS(T) \
+ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::SumReducer<T>); \
+ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MinReducer<T>); \
+ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::MaxReducer<T>); \
+ DEFINE_FOR_TYPE_AND_R(T, Eigen::internal::ProdReducer<T>)
+
+DEFINE_FOR_ALL_REDUCERS(float);
+#undef DEFINE_FOR_ALL_REDUCERS
+
+DEFINE_FOR_TYPE_AND_R(complex64, Eigen::internal::SumReducer<complex64>);
+DEFINE_FOR_TYPE_AND_R(bool, AllReducer);
+DEFINE_FOR_TYPE_AND_R(bool, AnyReducer);
+#undef DEFINE_FOR_TYPE_AND_R
+
+#undef DEFINE
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
new file mode 100644
index 0000000000..1749360b6e
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -0,0 +1,26 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReductionOp<CPUDevice, type, Eigen::internal::MaxReducer<type>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, Eigen::internal::MaxReducer<type>>);
+REGISTER_GPU_KERNELS(float);
+#undef REGISTER_GPU_KERNELS
+
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc
new file mode 100644
index 0000000000..b00c36fed8
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_mean.cc
@@ -0,0 +1,12 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Mean").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReductionOp<CPUDevice, type, Eigen::internal::MeanReducer<type>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
new file mode 100644
index 0000000000..de1f4b8520
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -0,0 +1,26 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReductionOp<CPUDevice, type, Eigen::internal::MinReducer<type>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, Eigen::internal::MinReducer<type>>);
+REGISTER_GPU_KERNELS(float);
+#undef REGISTER_GPU_KERNELS
+
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc
new file mode 100644
index 0000000000..4068c7feda
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_prod.cc
@@ -0,0 +1,26 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Prod").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReductionOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Prod") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>);
+REGISTER_GPU_KERNELS(float);
+#undef REGISTER_GPU_KERNELS
+
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
new file mode 100644
index 0000000000..82d685e225
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -0,0 +1,37 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReductionOp<CPUDevice, type, Eigen::internal::SumReducer<type>>);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+// NOTE: We should have mean(complex64, int32), too. But that requires
+// changing Eigen::internal::MeanReducer to cast int to complex<float>.
+// We don't see an immediate need for mean(complex64, int32) anyway.
+REGISTER_KERNEL_BUILDER(
+ Name("Sum").Device(DEVICE_CPU).TypeConstraint<complex64>("T"),
+ ReductionOp<CPUDevice, complex64, Eigen::internal::SumReducer<complex64>>);
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, Eigen::internal::SumReducer<type>>);
+REGISTER_GPU_KERNELS(float);
+#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(
+ Name("Sum").Device(DEVICE_GPU).TypeConstraint<complex64>("T"),
+ ReductionOp<GPUDevice, complex64, Eigen::internal::SumReducer<complex64>>);
+
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc
new file mode 100644
index 0000000000..d96da3c7f1
--- /dev/null
+++ b/tensorflow/core/kernels/reduction_ops_test.cc
@@ -0,0 +1,73 @@
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// Creates a Graph which "reduce"s a 3D float tensor of "num" elements
+// into a scalar.
+static Graph* ToScalar(const string& reduce, int num) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+ data.flat<float>().setRandom();
+ Tensor axes(DT_INT32, TensorShape({3}));
+ axes.flat<int32>()(0) = 0;
+ axes.flat<int32>()(1) = 1;
+ axes.flat<int32>()(2) = 2;
+ test::graph::Reduce(g, reduce, test::graph::Constant(g, data),
+ test::graph::Constant(g, axes));
+ return g;
+}
+
+// Creates a bench which reduces a 3D tensor with total "num" floats
+// into a scalar on a "device". Runs the bench for "iters" times.
+static void ReduceToScalar(int iters, const string& device,
+ const string& reduce, int num) {
+ testing::ItemsProcessed(static_cast<int64>(iters) * num);
+ testing::BytesProcessed(static_cast<int64>(iters) * num * sizeof(float));
+ test::Benchmark(device, ToScalar(reduce, num)).Run(iters);
+}
+
+static void BM_Sum3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Sum", num);
+}
+BENCHMARK(BM_Sum3DToScalarCPU)->Range(1 << 13, 1 << 20);
+
+static void BM_Max3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Max", num);
+}
+BENCHMARK(BM_Max3DToScalarCPU)->Range(1 << 13, 1 << 20);
+
+static void BM_Prod3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Prod", num);
+}
+BENCHMARK(BM_Prod3DToScalarCPU)->Range(1 << 13, 1 << 20);
+
+static void BM_Mean3DToScalarCPU(int iters, int num) {
+ ReduceToScalar(iters, "cpu", "Mean", num);
+}
+BENCHMARK(BM_Mean3DToScalarCPU)->Range(1 << 13, 1 << 20);
+
+static void BM_Sum3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Sum", num);
+}
+BENCHMARK(BM_Sum3DToScalarGPU)->Range(1 << 13, 1 << 20);
+
+static void BM_Max3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Max", num);
+}
+BENCHMARK(BM_Max3DToScalarGPU)->Range(1 << 13, 1 << 20);
+
+static void BM_Prod3DToScalarGPU(int iters, int num) {
+ ReduceToScalar(iters, "gpu", "Prod", num);
+}
+BENCHMARK(BM_Prod3DToScalarGPU)->Range(1 << 13, 1 << 20);
+
+// Once Mean is available on GPU, enable this.
+// static void BM_Mean3DToScalarGPU(int iters, int num) {
+// ReduceToScalar(iters, "gpu", "Mean", num);
+// }
+// BENCHMARK(BM_Mean3DToScalarGPU)->Range(1 << 13, 1 << 20);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/reference_gemm.h b/tensorflow/core/kernels/reference_gemm.h
new file mode 100644
index 0000000000..77c6ef35e9
--- /dev/null
+++ b/tensorflow/core/kernels/reference_gemm.h
@@ -0,0 +1,75 @@
+#ifndef TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
+#define TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
+
+// This is an unoptimized but debuggable implementation of the GEMM matrix
+// multiply function, used to compare to faster but more opaque versions, or
+// for bit depths or argument combinations that aren't supported by optimized
+// code.
+// It assumes the row-major convention used by TensorFlow, and implements
+// C = A * B, like the standard BLAS GEMM interface. If the transpose flags are
+// true, then the relevant matrix is treated as stored in column-major order.
+
+namespace tensorflow {
+template <class T1, class T2, class T3>
+void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c,
+ size_t m, size_t n, size_t k, const T1* a, T1 offset_a,
+ size_t lda, const T2* b, T2 offset_b, size_t ldb, T3* c,
+ int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) {
+ int a_i_stride;
+ int a_l_stride;
+ if (transpose_a) {
+ a_i_stride = 1;
+ a_l_stride = lda;
+ } else {
+ a_i_stride = lda;
+ a_l_stride = 1;
+ }
+ int b_j_stride;
+ int b_l_stride;
+ if (transpose_b) {
+ b_j_stride = ldb;
+ b_l_stride = 1;
+ } else {
+ b_j_stride = 1;
+ b_l_stride = ldb;
+ }
+ int c_i_stride;
+ int c_j_stride;
+ if (transpose_c) {
+ c_i_stride = 1;
+ c_j_stride = ldc;
+ } else {
+ c_i_stride = ldc;
+ c_j_stride = 1;
+ }
+
+ const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
+ const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());
+ const int32 rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1));
+
+ int i, j, l;
+ for (j = 0; j < n; j++) {
+ for (i = 0; i < m; i++) {
+ int32 total = 0;
+ for (l = 0; l < k; l++) {
+ const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
+ const int32 a_value = a[a_index] - offset_a;
+ const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
+ const int32 b_value = b[b_index] - offset_b;
+ total += (a_value * b_value);
+ }
+ const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
+ int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c);
+ if (output > highest) {
+ output = highest;
+ }
+ if (output < lowest) {
+ output = lowest;
+ }
+ c[c_index] = static_cast<T3>(output);
+ }
+ }
+}
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_REFERENCE_GEMM_H_
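The requantization at the end of the inner loop is the only non-obvious arithmetic in ReferenceGemm: accumulate (a - offset_a) * (b - offset_b) in 32 bits, then rescale with offset_c, mult_c and shift_c and clamp to the output range. A self-contained sketch of that per-element computation, with made-up offsets and shift values:

    // Illustration only (not part of the diff): the per-element quantized
    // arithmetic used by ReferenceGemm, for one output element with k = 2.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint8_t a[2] = {200, 10};
      const uint8_t b[2] = {50, 250};
      const int32_t offset_a = 128, offset_b = 128;
      const int32_t offset_c = 32768, mult_c = 1, shift_c = 8;  // example values

      int32_t total = 0;
      for (int l = 0; l < 2; ++l) {
        total += (a[l] - offset_a) * (b[l] - offset_b);
      }
      const int32_t rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1));
      int32_t output = (((total + offset_c) * mult_c) + rounding) >> shift_c;
      output = std::min(255, std::max(0, output));  // clamp to the uint8 range
      std::printf("%d\n", output);                  // prints 50
    }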
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
new file mode 100644
index 0000000000..d5dd7a8119
--- /dev/null
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -0,0 +1,154 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/relu_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Relu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+template <typename Device, typename T>
+class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Relu6<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+template <typename Device, typename T>
+class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): inputs that were passed to ReluOp()
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ functor::ReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+template <typename Device, typename T>
+class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): inputs that were passed to Relu6Op()
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ functor::Relu6Grad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6Op<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<CPUDevice, type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void Relu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct Relu<GPUDevice, T>; \
+ \
+ template <> \
+ void ReluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor backprops); \
+ \
+ extern template struct ReluGrad<GPUDevice, T>; \
+ template <> \
+ void Relu6<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct Relu6<GPUDevice, T>; \
+ \
+ template <> \
+ void Relu6Grad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct Relu6Grad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6Op<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
new file mode 100644
index 0000000000..8ed071cc4a
--- /dev/null
+++ b/tensorflow/core/kernels/relu_op.h
@@ -0,0 +1,79 @@
+#ifndef TENSORFLOW_KERNELS_RELU_OP_H_
+#define TENSORFLOW_KERNELS_RELU_OP_H_
+// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by ReluOp to do the computations.
+template <typename Device, typename T>
+struct Relu {
+ // Computes Relu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(static_cast<T>(0));
+ }
+};
+
+// Functor used by ReluGradOp to do the computations.
+template <typename Device, typename T>
+struct ReluGrad {
+ // Computes ReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the Relu op.
+  // features: inputs that were passed to the Relu op.
+ // backprops: gradients to backpropagate to the Relu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor backprops) {
+ // NOTE: When the activation is exactly zero, we arbitrarily choose to not
+ // propagate the associated gradient value.
+ backprops.device(d) =
+ gradients * (features > features.constant(static_cast<T>(0)));
+ }
+};
+
+// Functor used by Relu6Op to do the computations.
+template <typename Device, typename T>
+struct Relu6 {
+ // Computes Relu6 activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ activations.device(d) =
+ features.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(6));
+ }
+};
+
+// Functor used by ReluGradOp to do the computations.
+template <typename Device, typename T>
+struct Relu6Grad {
+ // Computes Relu6Grad backprops.
+ //
+ // gradients: gradients backpropagated to the Relu6 op.
+  // features: inputs that were passed to the Relu6 op.
+ // backprops: gradients to backpropagate to the Relu6 inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor backprops) {
+ // NOTE: When the activation is exactly zero or six, we
+ // arbitrarily choose to not propagate the associated gradient
+ // value.
+ backprops.device(d) = gradients *
+ (features > features.constant(static_cast<T>(0))) *
+ (features < features.constant(static_cast<T>(6)));
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RELU_OP_H_
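Element-wise, the two gradient functors reduce to simple masks: ReluGrad passes a gradient through only where the feature is strictly positive, and Relu6Grad additionally zeroes it where the feature has reached the upper clamp of 6. A scalar sketch with invented helper names:

    // Element-wise illustration (not part of the diff) of the masks used by
    // ReluGrad and Relu6Grad above.
    #include <cstdio>

    float ReluGradScalar(float g, float x) { return x > 0.0f ? g : 0.0f; }
    float Relu6GradScalar(float g, float x) {
      return (x > 0.0f && x < 6.0f) ? g : 0.0f;
    }

    int main() {
      std::printf("%g %g\n", ReluGradScalar(1.0f, -2.0f),
                  ReluGradScalar(1.0f, 3.0f));  // prints 0 1
      std::printf("%g %g\n", Relu6GradScalar(1.0f, 6.0f),
                  Relu6GradScalar(1.0f, 3.0f));  // prints 0 1
    }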
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
new file mode 100644
index 0000000000..6bd87ff8e4
--- /dev/null
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -0,0 +1,27 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include "tensorflow/core/kernels/relu_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Definition of the GPU implementations declared in relu_op.cc.
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Relu<GPUDevice, T>; \
+ template struct functor::ReluGrad<GPUDevice, T>; \
+ template struct functor::Relu6<GPUDevice, T>; \
+ template struct functor::Relu6Grad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/reshape_op.cc b/tensorflow/core/kernels/reshape_op.cc
new file mode 100644
index 0000000000..7e1cf029de
--- /dev/null
+++ b/tensorflow/core/kernels/reshape_op.cc
@@ -0,0 +1,29 @@
+// See docs in ../ops/array_ops.cc.
+#include "tensorflow/core/kernels/reshape_op.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("Reshape").Device(DEVICE_CPU).HostMemory("shape"),
+ ReshapeOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Reshape") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .TypeConstraint<type>("T"), \
+ ReshapeOp);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Reshape")
+ .Device(DEVICE_GPU)
+ .HostMemory("tensor")
+ .HostMemory("shape")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ ReshapeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h
new file mode 100644
index 0000000000..3fd3f4492e
--- /dev/null
+++ b/tensorflow/core/kernels/reshape_op.h
@@ -0,0 +1,83 @@
+#ifndef TENSORFLOW_KERNELS_RESHAPE_OP_H_
+#define TENSORFLOW_KERNELS_RESHAPE_OP_H_
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class ReshapeOp : public OpKernel {
+ public:
+ explicit ReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& sizes = context->input(1);
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(sizes.shape()),
+ errors::InvalidArgument("sizes input must be 1-D, not shape ",
+ sizes.shape().ShortDebugString()));
+ const int64 num_dims = sizes.NumElements();
+ OP_REQUIRES(
+ context, num_dims <= 8,
+ errors::InvalidArgument(num_dims, " > max 8 output dims supported"));
+
+ // Compute the output shape. Determine product of specified
+ // dimensions, and find the index of the unspecified one.
+ TensorShape shape;
+ int32 product = 1;
+ int unknown_index = -1;
+ auto Svec = sizes.flat<int32>();
+ for (int d = 0; d < num_dims; ++d) {
+ const int32 size = Svec(d);
+ if (size == -1) {
+ OP_REQUIRES(
+ context, unknown_index == -1,
+ errors::InvalidArgument("only one input size may be -1, not both ",
+ unknown_index, " and ", d));
+ unknown_index = d;
+ shape.AddDim(1);
+ } else {
+ OP_REQUIRES(context, size >= 0,
+ errors::InvalidArgument(
+ "size ", d, " must be non-negative, not ", size));
+ shape.AddDim(size);
+ product *= size;
+ }
+ }
+ if (unknown_index != -1) {
+ OP_REQUIRES(
+ context, product > 0,
+ errors::InvalidArgument("cannot infer the missing input size for "
+ "an empty tensor unless all specified "
+ "input sizes are non-zero"));
+ const int32 missing = input.NumElements() / product;
+ OP_REQUIRES(context, product * missing == input.NumElements(),
+ errors::InvalidArgument("Input has ", input.NumElements(),
+ " values, which isn't divisible by ",
+ product));
+ shape.set_dim(unknown_index, missing);
+ }
+ OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
+ errors::InvalidArgument("Input has ", input.NumElements(),
+ " values, which isn't the same as ",
+ shape.num_elements()));
+
+ // Actually produce the reshaped output.
+ Tensor output(input.dtype());
+ CHECK(output.CopyFrom(input, shape));
+ context->set_output(0, output);
+ }
+
+ bool IsExpensive() override { return false; }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_RESHAPE_OP_H_
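The shape-inference part of ReshapeOp::Compute allows at most one -1 entry, infers it as the input element count divided by the product of the known sizes, and then checks divisibility and the total element count. That step can be sketched on its own as follows, with an invented function name and plain types:

    // Standalone sketch (not part of the diff) of the -1 inference in
    // ReshapeOp::Compute. Returns false if the sizes are inconsistent.
    #include <cstdio>
    #include <vector>

    bool InferShape(long long num_elements, std::vector<long long>* sizes) {
      long long product = 1;
      int unknown_index = -1;
      for (size_t d = 0; d < sizes->size(); ++d) {
        if ((*sizes)[d] == -1) {
          if (unknown_index != -1) return false;  // at most one -1 allowed
          unknown_index = static_cast<int>(d);
        } else {
          product *= (*sizes)[d];
        }
      }
      if (unknown_index != -1) {
        if (product <= 0) return false;  // cannot infer from a zero product
        const long long missing = num_elements / product;
        if (missing * product != num_elements) return false;  // not divisible
        (*sizes)[unknown_index] = missing;
        product *= missing;
      }
      return product == num_elements;
    }

    int main() {
      std::vector<long long> sizes = {2, -1, 4};
      if (InferShape(24, &sizes)) {
        std::printf("%lld %lld %lld\n", sizes[0], sizes[1], sizes[2]);  // 2 3 4
      }
    }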
diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc
new file mode 100644
index 0000000000..2b22d38ad6
--- /dev/null
+++ b/tensorflow/core/kernels/resize_area_op.cc
@@ -0,0 +1,139 @@
+// See docs in ../ops/image_ops.cc
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class ResizeAreaOp : public OpKernel {
+ public:
+ explicit ResizeAreaOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().ShortDebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().ShortDebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().ShortDebugString()));
+
+ auto Svec = shape_t.vec<int32>();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({input.dim_size(0), Svec(0),
+ Svec(1), input.dim_size(3)}),
+ &output));
+ const int64 batch_size = input.dim_size(0);
+ const int64 in_height = input.dim_size(1);
+ const int64 in_width = input.dim_size(2);
+ const int64 channels = input.dim_size(3);
+ const int64 out_height = output->dim_size(1);
+ const int64 out_width = output->dim_size(2);
+
+ typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+ typename TTypes<float, 4>::Tensor output_data = output->tensor<float, 4>();
+
+ // A temporary tensor for computing the sum.
+ Tensor sum_tensor;
+ OP_REQUIRES_OK(
+ context, context->allocate_temp(DataTypeToEnum<float>::value,
+ TensorShape({channels}), &sum_tensor));
+ typename TTypes<float, 1>::Tensor sum_data = sum_tensor.vec<float>();
+
+ const float height_scale = in_height / static_cast<float>(out_height);
+ const float width_scale = in_width / static_cast<float>(out_width);
+
+ // When using this algorithm for downsizing, the target pixel value is the
+ // weighted average of all the source pixels. The weight is determined by
+ // the contribution percentage of the source pixel.
+ //
+ // Let "scale" be "target_image_size/source_image_size". If 1/n of the
+ // source pixel contributes to the target pixel, then the weight is (1/n *
+ // scale); if the complete source pixel contributes to the target pixel,
+ // then the weight is scale.
+ //
+ // To visualize the implementation, use one dimension as an example:
+ // Resize in[4] to out[3].
+ // scale = 3/4 = 0.75
+ // out[0]: in[0] and 1/3 of in[1]
+ // out[1]: 2/3 of in[1] and 2/3 of in[2]
+    // out[2]: 1/3 of in[2] and in[3]
+    // Hence, the output pixel values are:
+    // out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
+    // out[1] = (in[1] * 2/3 + in[2] * 2/3) * scale
+    // out[2] = (in[2] * 1/3 + in[3] * 1.0) * scale
+ float scale = 1.0 / (height_scale * width_scale);
+ for (int64 b = 0; b < batch_size; ++b) {
+ for (int64 y = 0; y < out_height; ++y) {
+ const float in_y = y * height_scale;
+ const float in_y1 = (y + 1) * height_scale;
+ // The start and end height indices of all the cells that could
+ // contribute to the target cell.
+ int64 y_start = floor(in_y);
+ int64 y_end = ceil(in_y1);
+
+ for (int64 x = 0; x < out_width; ++x) {
+ const float in_x = x * width_scale;
+ const float in_x1 = (x + 1) * width_scale;
+ // The start and end width indices of all the cells that could
+ // contribute to the target cell.
+ int64 x_start = floor(in_x);
+ int64 x_end = ceil(in_x1);
+
+ sum_data.setConstant(0.0);
+ for (int64 i = y_start; i < y_end; ++i) {
+ float scale_y =
+ i < in_y ? i + 1 - in_y : (i + 1 > in_y1 ? in_y1 - i : 1.0);
+ for (int64 j = x_start; j < x_end; ++j) {
+ float scale_x =
+ j < in_x ? j + 1 - in_x : (j + 1 > in_x1 ? in_x1 - j : 1.0);
+ for (int64 c = 0; c < channels; ++c) {
+#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
+ sum_data(c) +=
+ input_data(b, BOUND(i, in_height), BOUND(j, in_width), c) *
+ scale_y * scale_x * scale;
+#undef BOUND
+ }
+ }
+ }
+ for (int64 c = 0; c < channels; ++c) {
+ output_data(b, y, x, c) = sum_data(c);
+ }
+ }
+ }
+ }
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("ResizeArea") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("size"), \
+ ResizeAreaOp<CPUDevice, T>);
+
+REGISTER_KERNEL(uint8);
+REGISTER_KERNEL(int8);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
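The weighting scheme described in the comment is easiest to see in one dimension. The sketch below works through the in[4] to out[3] example: each output is the overlap-weighted average of the source cells it covers, with partial cells weighted by their overlap fraction. Values and names are illustrative only.

    // One-dimensional sketch (not part of the diff) of the area-averaging
    // weights used by ResizeAreaOp above.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
      const float in[4] = {10.0f, 20.0f, 30.0f, 40.0f};
      const int in_size = 4, out_size = 3;
      const float step = static_cast<float>(in_size) / out_size;   // 4/3
      const float scale = static_cast<float>(out_size) / in_size;  // 3/4
      for (int x = 0; x < out_size; ++x) {
        const float x0 = x * step;
        const float x1 = (x + 1) * step;
        const int i_begin = static_cast<int>(std::floor(x0));
        const int i_end = std::min(in_size, static_cast<int>(std::ceil(x1)));
        float sum = 0.0f;
        for (int i = i_begin; i < i_end; ++i) {
          // Fraction of source cell i covered by the output interval [x0, x1).
          const float overlap =
              std::min<float>(i + 1, x1) - std::max<float>(i, x0);
          sum += in[i] * overlap;
        }
        std::printf("out[%d] = %g\n", x, sum * scale);  // ~12.5, 25, 37.5
      }
    }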
diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
new file mode 100644
index 0000000000..472fc19b82
--- /dev/null
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -0,0 +1,121 @@
+// See docs in ../ops/image_ops.cc
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class ResizeBicubicOp : public OpKernel {
+ public:
+ explicit ResizeBicubicOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().ShortDebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().ShortDebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().ShortDebugString()));
+
+ auto Svec = shape_t.vec<int32>();
+ // Initialize shape to the batch size of the input, then add
+ // the rest of the dimensions
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({input.dim_size(0), Svec(0),
+ Svec(1), input.dim_size(3)}),
+ &output));
+ const int64 batch_size = input.dim_size(0);
+ const int64 in_height = input.dim_size(1);
+ const int64 in_width = input.dim_size(2);
+ const int64 channels = input.dim_size(3);
+ const int64 out_height = output->dim_size(1);
+ const int64 out_width = output->dim_size(2);
+
+ typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+ typename TTypes<float, 4>::Tensor output_data = output->tensor<float, 4>();
+
+ const float height_scale = in_height / static_cast<float>(out_height);
+ const float width_scale = in_width / static_cast<float>(out_width);
+
+ // Initialize coefficients table using Bicubic convolution algorithm.
+ // https://en.wikipedia.org/wiki/Bicubic_interpolation
+ static const int64 tab_size = (1 << 10);
+ static float coeffs_tab[(tab_size + 1) * 2];
+ static const double A = -0.75;
+ for (int i = 0; i <= tab_size; ++i) {
+ float x = i * 1.0 / tab_size;
+ coeffs_tab[i * 2] = ((A + 2) * x - (A + 3)) * x * x + 1;
+ x += 1.0;
+ coeffs_tab[i * 2 + 1] = ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+ }
+
+ auto cal = [](float v0, float v1, float v2, float v3, float dx) {
+ const int64 offset = round(dx * tab_size);
+ const float a0 = coeffs_tab[offset * 2 + 1];
+ const float a1 = coeffs_tab[offset * 2];
+ const float a2 = coeffs_tab[(tab_size - offset) * 2];
+ const float a3 = coeffs_tab[(tab_size - offset) * 2 + 1];
+ return a0 * v0 + a1 * v1 + a2 * v2 + a3 * v3;
+ };
+
+ float coeff[4] = {0.0};
+ for (int64 b = 0; b < batch_size; ++b) {
+ for (int64 y = 0; y < out_height; ++y) {
+ const int64 in_y = floor(height_scale * y);
+ const float dy = height_scale * y - in_y;
+ for (int64 x = 0; x < out_width; ++x) {
+ const int64 in_x = floor(width_scale * x);
+ const float dx = width_scale * x - in_x;
+ for (int64 c = 0; c < channels; ++c) {
+ for (int64 i = 0; i < 4; ++i) {
+#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
+ int64 bound_y = BOUND(in_y - 1 + i, in_height);
+ coeff[i] =
+ cal(input_data(b, bound_y, BOUND(in_x - 1, in_width), c),
+ input_data(b, bound_y, BOUND(in_x, in_width), c),
+ input_data(b, bound_y, BOUND(in_x + 1, in_width), c),
+ input_data(b, bound_y, BOUND(in_x + 2, in_width), c), dx);
+#undef BOUND
+ }
+ output_data(b, y, x, c) =
+ cal(coeff[0], coeff[1], coeff[2], coeff[3], dy);
+ }
+ }
+ }
+ }
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("ResizeBicubic") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("size"), \
+ ResizeBicubicOp<CPUDevice, T>);
+
+REGISTER_KERNEL(uint8);
+REGISTER_KERNEL(int8);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
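The coeffs_tab built above tabulates the two cubic-convolution weight polynomials with A = -0.75, one for taps within distance 1 of the sample point and one for taps between distance 1 and 2; for any fractional offset dx the four tap weights sum to 1. A small sketch evaluating them directly (names invented):

    // Standalone sketch (not part of the diff) of the cubic-convolution
    // weights that ResizeBicubicOp tabulates in coeffs_tab.
    #include <cstdio>

    int main() {
      const double A = -0.75;
      // Weight for a tap at distance |x| <= 1 from the sample point.
      auto w_near = [&](double x) { return ((A + 2) * x - (A + 3)) * x * x + 1; };
      // Weight for a tap at distance 1 < |x| < 2 from the sample point.
      auto w_far = [&](double x) {
        return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
      };

      const double dx = 0.25;  // fractional position between the middle taps
      // Taps sit at distances 1 + dx, dx, 1 - dx, 2 - dx from the sample point.
      const double w0 = w_far(1.0 + dx);
      const double w1 = w_near(dx);
      const double w2 = w_near(1.0 - dx);
      const double w3 = w_far(2.0 - dx);
      std::printf("%f %f %f %f (sum %f)\n", w0, w1, w2, w3,
                  w0 + w1 + w2 + w3);  // the sum is 1
    }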
diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc
new file mode 100644
index 0000000000..5119b93508
--- /dev/null
+++ b/tensorflow/core/kernels/resize_bilinear_op.cc
@@ -0,0 +1,109 @@
+// See docs in ../ops/image_ops.cc
+#define EIGEN_USE_THREADS
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class ResizeBilinearOp : public OpKernel {
+ public:
+ explicit ResizeBilinearOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().ShortDebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().ShortDebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().ShortDebugString()));
+
+ auto Svec = shape_t.vec<int32>();
+ // Initialize shape to the batch size of the input, then add
+ // the rest of the dimensions
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({input.dim_size(0), Svec(0),
+ Svec(1), input.dim_size(3)}),
+ &output));
+
+ const int64 batch_size = input.dim_size(0);
+ const int64 in_height = input.dim_size(1);
+ const int64 in_width = input.dim_size(2);
+ const int64 channels = input.dim_size(3);
+ const int64 out_height = output->dim_size(1);
+ const int64 out_width = output->dim_size(2);
+
+ typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+ typename TTypes<float, 4>::Tensor output_data = output->tensor<float, 4>();
+
+ const float height_scale = in_height / static_cast<float>(out_height);
+ const float width_scale = in_width / static_cast<float>(out_width);
+
+ for (int b = 0; b < batch_size; ++b) {
+ for (int y = 0; y < out_height; ++y) {
+ const float in_y = y * height_scale;
+ const int top_y_index = static_cast<int>(floorf(in_y));
+ const int bottom_y_index =
+ std::min(static_cast<int64>(ceilf(in_y)), (in_height - 1));
+ const float y_lerp = in_y - top_y_index;
+ const float inverse_y_lerp = (1.0f - y_lerp);
+ for (int x = 0; x < out_width; ++x) {
+ const float in_x = x * width_scale;
+ const int left_x_index = static_cast<int>(floorf(in_x));
+ const int right_x_index =
+ std::min(static_cast<int64>(ceilf(in_x)), (in_width - 1));
+ const float x_lerp = in_x - left_x_index;
+ const float inverse_x_lerp = (1.0f - x_lerp);
+ for (int c = 0; c < channels; ++c) {
+ const float top_left = input_data(b, top_y_index, left_x_index, c);
+ const float top_right =
+ input_data(b, top_y_index, right_x_index, c);
+ const float bottom_left =
+ input_data(b, bottom_y_index, left_x_index, c);
+ const float bottom_right =
+ input_data(b, bottom_y_index, right_x_index, c);
+ const float top =
+ (top_left * inverse_x_lerp) + (top_right * x_lerp);
+ const float bottom =
+ (bottom_left * inverse_x_lerp) + (bottom_right * x_lerp);
+ output_data(b, y, x, c) =
+ (top * inverse_y_lerp) + (bottom * y_lerp);
+ }
+ }
+ }
+ }
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("ResizeBilinear") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("size"), \
+ ResizeBilinearOp<CPUDevice, T>);
+
+REGISTER_KERNEL(uint8);
+REGISTER_KERNEL(int8);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
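The inner loop above is a standard bilinear lerp: each output pixel maps to (in_y, in_x) = (y * height_scale, x * width_scale) and blends the four surrounding input pixels. A minimal single-channel sketch of the same arithmetic follows; the image/in_h/in_w names are hypothetical, not part of this kernel.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // "image" is a row-major in_h x in_w single-channel buffer.
    float BilinearSample(const std::vector<float>& image, int in_h, int in_w,
                         float in_y, float in_x) {
      const int top = static_cast<int>(std::floor(in_y));
      const int bottom = std::min(static_cast<int>(std::ceil(in_y)), in_h - 1);
      const int left = static_cast<int>(std::floor(in_x));
      const int right = std::min(static_cast<int>(std::ceil(in_x)), in_w - 1);
      const float y_lerp = in_y - top;
      const float x_lerp = in_x - left;
      const float top_row = image[top * in_w + left] * (1.0f - x_lerp) +
                            image[top * in_w + right] * x_lerp;
      const float bottom_row = image[bottom * in_w + left] * (1.0f - x_lerp) +
                               image[bottom * in_w + right] * x_lerp;
      return top_row * (1.0f - y_lerp) + bottom_row * y_lerp;
    }

For the 2x2 input {1, 2, 3, 4} resized to 3x3 (scale 2/3), output (0, 1) samples in_x = 2/3 and evaluates to 1 * (1/3) + 2 * (2/3) = 5/3, matching the expected values in resize_bilinear_op_test.cc below.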
diff --git a/tensorflow/core/kernels/resize_bilinear_op_test.cc b/tensorflow/core/kernels/resize_bilinear_op_test.cc
new file mode 100644
index 0000000000..0ebe2e5f8c
--- /dev/null
+++ b/tensorflow/core/kernels/resize_bilinear_op_test.cc
@@ -0,0 +1,171 @@
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+class ResizeBilinearOpTest : public OpsTestBase {
+ protected:
+ ResizeBilinearOpTest() {
+ RequireDefaultOps();
+ EXPECT_OK(NodeDefBuilder("resize_bilinear_op", "ResizeBilinear")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(ResizeBilinearOpTest, TestBilinear2x2To1x1) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {1, 1});
+ ASSERT_OK(RunOpKernel());
+
+ // When scaling down, we have to arbitrarily pick a pixel from the
+ // original input. In this case, we choose the top/left most pixel.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1}));
+ test::FillValues<float>(&expected, {1.0});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeBilinearOpTest, TestBilinear2x2To3x3) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 1}));
+
+ // The corners should match the original corners, and we bilinear
+ // interpolate the values in between.
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 5.0/3, 2,
+ 7.0/3, 3, 10.0/3,
+ 3, 11.0/3, 4});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeBilinearOpTest, TestBilinear3x3To4x4) {
+ // Input:
+ // 1, 2, 3,
+ // 4, 5, 6,
+ // 7, 8, 9
+ AddInputFromArray<float>(TensorShape({1, 3, 3, 1}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ AddInputFromArray<int32>(TensorShape({2}), {4, 4});
+ ASSERT_OK(RunOpKernel());
+
+ // The corners should match the original corners, and we bilinear
+ // interpolate the values in between.
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 4, 4, 1}));
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 1.75, 2.5, 3,
+ 3.25, 4, 4.75, 5.25,
+ 5.5, 6.25, 7, 7.5,
+ 7, 7.75, 8.5, 9});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeBilinearOpTest, TestBilinear2x2To3x3Batch2) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ //
+ // repeated twice
+ AddInputFromArray<float>(TensorShape({2, 2, 2, 1}), {1, 2, 3, 4, 1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 3, 1}));
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 5.0/3, 2, 7.0/3, 3, 10.0/3, 3, 11.0/3, 4,
+ 1, 5.0/3, 2, 7.0/3, 3, 10.0/3, 3, 11.0/3, 4
+ });
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeBilinearOpTest, TestBilinear2x2x2To3x3x2) {
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 2}),
+ {1, -1, 2, -2, 3, -3, 4, -4});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 2}));
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {
+ 1, -1,
+ 5.0/3, -5.0/3,
+ 2, -2,
+ 7.0/3, -7.0/3,
+ 3, -3,
+ 10.0/3, -10.0/3,
+ 3, -3,
+ 11.0/3, -11.0/3,
+ 4, -4
+ });
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeBilinearOpTest, TestBilinear2x2To4x4) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {4, 4});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 4, 4, 1}));
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 1.5, 2, 2,
+ 2, 2.5, 3, 3,
+ 3, 3.5, 4, 4,
+ 3, 3.5, 4, 4});
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeBilinearOpTest, TestInvalidInputShape) {
+ AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {4, 4});
+ ASSERT_FALSE(RunOpKernel().ok());
+}
+
+TEST_F(ResizeBilinearOpTest, TestInvalidSizeDim) {
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2, 1}), {4, 4});
+ ASSERT_FALSE(RunOpKernel().ok());
+}
+
+TEST_F(ResizeBilinearOpTest, TestInvalidSizeElements) {
+}
+
+TEST_F(ResizeBilinearOpTest, TestInvalidSizeElements) {
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({3}), {4, 4, 1});
+ ASSERT_FALSE(RunOpKernel().ok());
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
new file mode 100644
index 0000000000..13089308ce
--- /dev/null
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -0,0 +1,89 @@
+// See docs in ../ops/image_ops.cc
+#define EIGEN_USE_THREADS
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class ResizeNearestNeighborOp : public OpKernel {
+ public:
+ explicit ResizeNearestNeighborOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().ShortDebugString()));
+ const Tensor& shape_t = context->input(1);
+ OP_REQUIRES(context, shape_t.dims() == 1,
+ errors::InvalidArgument("shape_t must be 1-dimensional",
+ shape_t.shape().ShortDebugString()));
+ OP_REQUIRES(context, shape_t.NumElements() == 2,
+ errors::InvalidArgument("shape_t must have two elements",
+ shape_t.shape().ShortDebugString()));
+
+ auto Svec = shape_t.vec<int32>();
+ // Initialize shape to the batch size of the input, then add
+ // the rest of the dimensions
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({input.dim_size(0), Svec(0),
+ Svec(1), input.dim_size(3)}),
+ &output));
+
+ const int64 batch_size = input.dim_size(0);
+ const int64 in_height = input.dim_size(1);
+ const int64 in_width = input.dim_size(2);
+ const int64 channels = input.dim_size(3);
+ const int64 out_height = output->dim_size(1);
+ const int64 out_width = output->dim_size(2);
+
+ typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+ typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
+
+ const float height_scale = in_height / static_cast<float>(out_height);
+ const float width_scale = in_width / static_cast<float>(out_width);
+
+ for (int b = 0; b < batch_size; ++b) {
+ for (int y = 0; y < out_height; ++y) {
+ const int in_y = std::min(static_cast<int64>(floorf(y * height_scale)),
+ (in_height - 1));
+ for (int x = 0; x < out_width; ++x) {
+ const int in_x = std::min(static_cast<int64>(floorf(x * width_scale)),
+ (in_width - 1));
+ for (int c = 0; c < channels; ++c) {
+ output_data(b, y, x, c) = input_data(b, in_y, in_x, c);
+ }
+ }
+ }
+ }
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("size"), \
+ ResizeNearestNeighborOp<CPUDevice, T>);
+
+REGISTER_KERNEL(uint8);
+REGISTER_KERNEL(int8);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+} // namespace tensorflow
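Nearest-neighbor resizing above reduces to one index computation per axis: the source index is floor(out_index * scale), clamped to the last valid input index. A small sketch of that mapping (names illustrative):

    #include <algorithm>
    #include <cmath>

    inline int NearestSourceIndex(int out_index, float scale, int in_size) {
      return std::min(static_cast<int>(std::floor(out_index * scale)),
                      in_size - 1);
    }

With scale 2/3 (a 2x2 input resized to 3x3), outputs {0, 1, 2} map to inputs {0, 0, 1}, which is why the test below expects rows {1, 1, 2}, {1, 1, 2}, {3, 3, 4}.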
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc
new file mode 100644
index 0000000000..8fca1f34e3
--- /dev/null
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op_test.cc
@@ -0,0 +1,163 @@
+// TODO(shlens, sherrym): Consider adding additional tests in image_ops.py in
+// order to compare against the reference implementation for image resizing in
+// the Python Imaging Library.
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+class ResizeNearestNeighborOpTest : public OpsTestBase {
+ protected:
+ ResizeNearestNeighborOpTest() {
+ RequireDefaultOps();
+ EXPECT_OK(NodeDefBuilder("resize_nn", "ResizeNearestNeighbor")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Finalize(node_def()));
+ EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To1x1) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {1, 1});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 1, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected, {1});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To3x3) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 3, 3, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 1, 2,
+ 1, 1, 2,
+ 3, 3, 4});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To2x5) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {2, 5});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 5, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 1, 1, 2, 2,
+ 3, 3, 3, 4, 4});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To5x2) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {5, 2});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 5, 2, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 2,
+ 1, 2,
+ 1, 2,
+ 3, 4,
+ 3, 4});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To4x4) {
+ // Input:
+ // 1, 2
+ // 3, 4
+ AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
+ AddInputFromArray<int32>(TensorShape({2}), {4, 4});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 4, 4, 1}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 1, 2, 2,
+ 1, 1, 2, 2,
+ 3, 3, 4, 4,
+ 3, 3, 4, 4});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2x2x2To2x3x3x2) {
+ // Input:
+ // [ [ 1, 1 ], [ 2, 2],
+ // [ 3, 3 ], [ 4, 4] ],
+ // [ [ 5, 5 ], [ 6, 6],
+ // [ 7, 7 ], [ 8, 8] ]
+ AddInputFromArray<float>(TensorShape({2, 2, 2, 2}),
+ {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8});
+ AddInputFromArray<int32>(TensorShape({2}), {3, 3});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 3, 2}));
+
+ // clang-format off
+ test::FillValues<float>(&expected,
+ {1, 1, 1,
+ 1, 2, 2,
+ 1, 1, 1,
+ 1, 2, 2,
+ 3, 3, 3,
+ 3, 4, 4,
+ 5, 5, 5,
+ 5, 6, 6,
+ 5, 5, 5,
+ 5, 6, 6,
+ 7, 7, 7,
+ 7, 8, 8});
+
+ // clang-format on
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/restore_op.cc b/tensorflow/core/kernels/restore_op.cc
new file mode 100644
index 0000000000..b52c69449c
--- /dev/null
+++ b/tensorflow/core/kernels/restore_op.cc
@@ -0,0 +1,65 @@
+// See docs in ../ops/io_ops.cc.
+#include "tensorflow/core/kernels/io.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+
+namespace tensorflow {
+
+class RestoreOp : public OpKernel {
+ public:
+ explicit RestoreOp(OpKernelConstruction* context) : OpKernel(context) {
+ int preferred_shard;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("preferred_shard", &preferred_shard));
+ if (preferred_shard == -1) {
+ preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards;
+ } else {
+ OP_REQUIRES(context, preferred_shard >= 0,
+ errors::InvalidArgument("Attribute 'preferred_shard' must be "
+ "greater or equal to -1"));
+ preferred_shard_ = preferred_shard;
+ }
+ }
+ void Compute(OpKernelContext* context) override {
+ RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
+ preferred_shard_, false);
+ }
+
+ private:
+ int preferred_shard_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Restore").Device(DEVICE_CPU), RestoreOp);
+
+class RestoreSliceOp : public OpKernel {
+ public:
+ explicit RestoreSliceOp(OpKernelConstruction* context) : OpKernel(context) {
+ int preferred_shard;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("preferred_shard", &preferred_shard));
+ if (preferred_shard == -1) {
+ preferred_shard_ = checkpoint::TensorSliceReader::kLoadAllShards;
+ } else {
+ OP_REQUIRES(context, preferred_shard >= 0,
+ errors::InvalidArgument("Attribute 'preferred_shard' must be "
+ "greater or equal to -1"));
+ preferred_shard_ = preferred_shard;
+ }
+ }
+ void Compute(OpKernelContext* context) override {
+ RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
+ preferred_shard_, true);
+ }
+
+ private:
+ int preferred_shard_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RestoreSlice").Device(DEVICE_CPU),
+ RestoreSliceOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc
new file mode 100644
index 0000000000..59343a8037
--- /dev/null
+++ b/tensorflow/core/kernels/restore_op_test.cc
@@ -0,0 +1,305 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class RestoreOpTest : public OpsTestBase {
+ protected:
+ // Makes an operation to restore two tensors
+ void MakeRestoreOp(DataType dt) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "Restore")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Attr("dt", dt)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(RestoreOpTest, RestoreInt) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_int");
+ const string tensor_name = "tensor_int";
+
+ // We first need to write a tensor using the save_op
+ {
+ // Initialize an operation
+ NodeDef save;
+ ASSERT_OK(NodeDefBuilder("save", "Save")
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput({DT_INT32}))
+ .Finalize(&save));
+
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), save, &status));
+ EXPECT_OK(status);
+
+ // Run it
+
+ // Input #0 is the file name
+ Tensor input_0(DT_STRING, TensorShape({}));
+ input_0.scalar<string>()() = filename;
+ inputs.push_back({nullptr, &input_0});
+
+ // Input #1 is the tensor name
+ Tensor input_1(DT_STRING, TensorShape({}));
+ input_1.scalar<string>()() = tensor_name;
+ inputs.push_back({nullptr, &input_1});
+
+ // Input #2 is an integer tensor: it's a 1-d array.
+ Tensor input_2(DT_INT32, TensorShape({10}));
+ for (int i = 0; i < 10; ++i) {
+ input_2.flat<int32>()(i) = i + 1;
+ }
+ inputs.push_back({nullptr, &input_2});
+
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+ checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
+ params.slice_reader_cache = &slice_reader_cache_wrapper;
+
+ OpKernelContext ctx(params);
+ op->Compute(&ctx);
+ EXPECT_OK(ctx.status());
+ }
+
+ // Now we restore
+ MakeRestoreOp(DT_INT32);
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+ // Add the tensor names
+ AddInput<string>(TensorShape({}),
+ [&tensor_name](int x) -> string { return tensor_name; });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that we have an integer tensor
+ Tensor* output = GetOutput(0);
+ TensorShape expected({10});
+ EXPECT_TRUE(output->shape().IsSameSize(expected));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(i + 1, output->flat<int32>()(i));
+ }
+}
+
+TEST_F(RestoreOpTest, RestoreFloat) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_float");
+ const string tensor_name = "tensor_float";
+
+ // We first need to write a tensor using the save_op
+ {
+ // Initialize an operation
+ NodeDef save;
+ ASSERT_OK(NodeDefBuilder("save", "Save")
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput({DT_FLOAT}))
+ .Finalize(&save));
+
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+ gtl::InlinedVector<TensorValue, 4> inputs;
+
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), save, &status));
+ EXPECT_OK(status);
+
+ // Run it
+
+ // Input #0 is the file name
+ Tensor input_0(DT_STRING, TensorShape({}));
+ input_0.scalar<string>()() = filename;
+ inputs.push_back({nullptr, &input_0});
+
+ // Input #1 is the tensor name
+ Tensor input_1(DT_STRING, TensorShape({}));
+ input_1.scalar<string>()() = tensor_name;
+ inputs.push_back({nullptr, &input_1});
+
+ // Input #2 is a float tensor: it's a 2-d array.
+ Tensor input_2(DT_FLOAT, TensorShape({2, 4}));
+ for (int i = 0; i < 8; ++i) {
+ input_2.flat<float>()(i) = static_cast<float>(i) / 10;
+ }
+ inputs.push_back({nullptr, &input_2});
+
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+ checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
+ params.slice_reader_cache = &slice_reader_cache_wrapper;
+
+ OpKernelContext ctx(params);
+ op->Compute(&ctx);
+ EXPECT_OK(ctx.status());
+ }
+
+ // Now we restore
+ MakeRestoreOp(DT_FLOAT);
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+ // Add the tensor names
+ AddInput<string>(TensorShape({}),
+ [&tensor_name](int x) -> string { return tensor_name; });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that we have a float tensor.
+ Tensor* output = GetOutput(0);
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(output->shape().IsSameSize(expected));
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(static_cast<float>(i) / 10, output->flat<float>()(i));
+ }
+}
+
+class RestoreSliceOpTest : public OpsTestBase {
+ protected:
+ void MakeRestoreSliceOp(DataType dt) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "RestoreSlice")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Attr("dt", dt)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(RestoreSliceOpTest, RestoreInt) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_int");
+ const string tensor_name = "tensor_int";
+
+ // We first need to write a tensor using the save_op
+ {
+ // Initialize an operation
+ NodeDef save;
+ ASSERT_OK(NodeDefBuilder("save", "Save")
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput(DT_STRING))
+ .Input(FakeInput({DT_INT32}))
+ .Finalize(&save));
+
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), save, &status));
+ EXPECT_OK(status);
+
+ // Run it
+
+ // Input #0 is the file name
+ Tensor input_0(DT_STRING, TensorShape({}));
+ input_0.scalar<string>()() = filename;
+ inputs.push_back({nullptr, &input_0});
+
+ // Input #1 is the tensor name
+ Tensor input_1(DT_STRING, TensorShape({}));
+ input_1.scalar<string>()() = tensor_name;
+ inputs.push_back({nullptr, &input_1});
+
+ // Input #2 is a 4x16 integer tensor.
+ Tensor input_2(DT_INT32, TensorShape({4, 16}));
+ for (int64 i = 0; i < input_2.NumElements(); ++i) {
+ input_2.flat<int32>()(i) = i + 1;
+ }
+ inputs.push_back({nullptr, &input_2});
+
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+ checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
+ params.slice_reader_cache = &slice_reader_cache_wrapper;
+
+ OpKernelContext ctx(params);
+ op->Compute(&ctx);
+ EXPECT_OK(ctx.status());
+ }
+
+ // Now we restore
+ MakeRestoreSliceOp(DT_INT32);
+ string shape_and_slice = "4 16 0,2:-";
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+ // Add the tensor names
+ AddInput<string>(TensorShape({}),
+ [&tensor_name](int x) -> string { return tensor_name; });
+ // Add the tensor shape and slice
+ AddInput<string>(TensorShape({}), [&shape_and_slice](int x) -> string {
+ return shape_and_slice;
+ });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that we have an integer tensor
+ Tensor* output = GetOutput(0);
+ TensorShape expected({2, 16});
+ EXPECT_TRUE(output->shape().IsSameSize(expected));
+ for (int64 i = 0; i < expected.num_elements(); ++i) {
+ EXPECT_EQ(i + 1, output->flat<int32>()(i));
+ }
+}
+
+} // namespace
+} // namespace tensorflow
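The shape_and_slice string "4 16 0,2:-" above uses the tensor-slice text format also exercised in save_op_test.cc later in this change: the leading numbers give the full tensor shape, and the colon-separated part selects a start,length range per dimension, with "-" meaning the whole dimension. Illustrative readings:

    // "4 16 0,2:-"  -> full shape 4x16; rows [0, 2), all 16 columns (hence
    //                  the 2x16 output checked above)
    // "2 4 -:0,2"   -> full shape 2x4;  both rows, columns [0, 2)
    // "10 -"        -> full shape {10}; the whole vector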
diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc
new file mode 100644
index 0000000000..c63dfc1e70
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_op.cc
@@ -0,0 +1,139 @@
+// See docs in ../ops/array_ops.cc
+#define EIGEN_USE_THREADS
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/reverse_op.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class ReverseOp : public OpKernel {
+ public:
+ explicit ReverseOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& dims = context->input(1);
+
+ if (TensorShapeUtils::IsScalar(input.shape())) {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ output->scalar<T>() = input.scalar<T>();
+
+ } else {
+ const int input_dims = input.dims();
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(dims.shape()),
+ errors::InvalidArgument("'dims' must be 1-dimension, not ",
+ dims.dims()));
+
+ OP_REQUIRES(context, input_dims == dims.dim_size(0),
+ errors::InvalidArgument(
+ "'dims' must have the same number of values as 'input' has "
+ "dimensions. 'input' has ", input_dims, "'dims' has ",
+ dims.dim_size(0), " values"));
+ OP_REQUIRES(context, input_dims <= 8, errors::Unimplemented(
+ "reverse is not implemented for tensors of rank > 8."));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+#define HANDLE_REVERSE(NDIMS) \
+ case NDIMS: \
+ functor::Reverse<Device, T, NDIMS>()( \
+ context->eigen_device<Device>(), input.tensor<T, NDIMS>(), \
+ dims.vec<bool>(), output->tensor<T, NDIMS>()); \
+ return;
+
+ switch (input_dims) {
+ HANDLE_REVERSE(0);
+ HANDLE_REVERSE(1);
+ HANDLE_REVERSE(2);
+ HANDLE_REVERSE(3);
+ HANDLE_REVERSE(4);
+ HANDLE_REVERSE(5);
+ HANDLE_REVERSE(6);
+ HANDLE_REVERSE(7);
+ HANDLE_REVERSE(8);
+ }
+#undef HANDLE_REVERSE
+ }
+ }
+};
+
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("Reverse") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("dims"), \
+ ReverseOp<CPUDevice, T>)
+
+REGISTER_KERNEL(uint8);
+REGISTER_KERNEL(int8);
+REGISTER_KERNEL(int32);
+REGISTER_KERNEL(bool);
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the function specializations for GPU (to prevent
+// building the GPU versions here, they will be built compiling _gpu.cu.cc).
+namespace functor {
+#define DECLARE_GPU_SPEC_DIM(T, DIM) \
+ template <> \
+ void Reverse<GPUDevice, T, DIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, DIM>::ConstTensor input, \
+ typename TTypes<bool, 1>::ConstTensor dims, \
+ typename TTypes<T, DIM>::Tensor output); \
+ extern template struct Reverse<GPUDevice, T, DIM>;
+#define DECLARE_GPU_SPEC(T) \
+ DECLARE_GPU_SPEC_DIM(T, 0) \
+ DECLARE_GPU_SPEC_DIM(T, 1) \
+ DECLARE_GPU_SPEC_DIM(T, 2) \
+ DECLARE_GPU_SPEC_DIM(T, 3) \
+ DECLARE_GPU_SPEC_DIM(T, 4) \
+ DECLARE_GPU_SPEC_DIM(T, 5) \
+ DECLARE_GPU_SPEC_DIM(T, 6) \
+ DECLARE_GPU_SPEC_DIM(T, 7) \
+ DECLARE_GPU_SPEC_DIM(T, 8)
+
+DECLARE_GPU_SPEC(uint8);
+DECLARE_GPU_SPEC(int8);
+DECLARE_GPU_SPEC(int32);
+DECLARE_GPU_SPEC(bool);
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+#undef DECLARE_GPU_SPEC_DIM
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("Reverse") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("dims"), \
+ ReverseOp<GPUDevice, T>)
+REGISTER_GPU_KERNEL(uint8);
+REGISTER_GPU_KERNEL(int8);
+REGISTER_GPU_KERNEL(float);
+REGISTER_GPU_KERNEL(double);
+#undef REGISTER_GPU_KERNEL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
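The per-rank dispatch above bottoms out in a single Eigen expression (see the functor in reverse_op.h below). A minimal standalone sketch of that call, with hypothetical values and the same include path used in this file:

    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

    void ReverseExample() {
      Eigen::Tensor<float, 2> input(2, 3);
      input.setValues({{1, 2, 3}, {4, 5, 6}});
      // One bool per dimension: flip only the second (inner) dimension.
      Eigen::array<bool, 2> reverse_dims{{false, true}};
      Eigen::Tensor<float, 2> output = input.reverse(reverse_dims);
      // output now holds {{3, 2, 1}, {6, 5, 4}}.
    }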
diff --git a/tensorflow/core/kernels/reverse_op.h b/tensorflow/core/kernels/reverse_op.h
new file mode 100644
index 0000000000..bba25f70e8
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_op.h
@@ -0,0 +1,28 @@
+#ifndef TENSORFLOW_KERNELS_REVERSE_OP_H_
+#define TENSORFLOW_KERNELS_REVERSE_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by ReverseOp to do the computations.
+template <typename Device, typename T, int Dims>
+struct Reverse {
+ void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
+ typename TTypes<bool, 1>::ConstTensor dims,
+ typename TTypes<T, Dims>::Tensor output) {
+ // 'dims' is in host memory.
+ Eigen::array<bool, Dims> reverse_dims;
+ for (int i = 0; i < Dims; ++i) {
+ reverse_dims[i] = dims(i);
+ }
+ output.device(d) = input.reverse(reverse_dims);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_REVERSE_OP_H_
diff --git a/tensorflow/core/kernels/reverse_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_op_gpu.cu.cc
new file mode 100644
index 0000000000..b510add3f3
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_op_gpu.cu.cc
@@ -0,0 +1,33 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/reverse_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_REVERSE(DIM) \
+ template struct functor::Reverse<GPUDevice, uint8, DIM>; \
+ template struct functor::Reverse<GPUDevice, int8, DIM>; \
+ template struct functor::Reverse<GPUDevice, int32, DIM>; \
+ template struct functor::Reverse<GPUDevice, bool, DIM>; \
+ template struct functor::Reverse<GPUDevice, float, DIM>; \
+ template struct functor::Reverse<GPUDevice, double, DIM>;
+DEFINE_REVERSE(0)
+DEFINE_REVERSE(1)
+DEFINE_REVERSE(2)
+DEFINE_REVERSE(3)
+DEFINE_REVERSE(4)
+DEFINE_REVERSE(5)
+DEFINE_REVERSE(6)
+DEFINE_REVERSE(7)
+DEFINE_REVERSE(8)
+#undef DEFINE_REVERSE
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc
new file mode 100644
index 0000000000..d41c36e693
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_op_test.cc
@@ -0,0 +1,101 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class ReverseOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType data_type) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "Reverse")
+ .Input(FakeInput(data_type))
+ .Input(FakeInput())
+ .Attr("T", data_type)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(ReverseOpTest, Reverse_0) {
+ MakeOp(DT_FLOAT);
+ AddInputFromArray<float>(TensorShape({}), {3});
+ AddInputFromArray<bool>(TensorShape({}), {true});
+ ASSERT_OK(RunOpKernel());
+
+ Tensor* output = GetOutput(0);
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({}));
+ expected.scalar<float>() = expected.scalar<float>().constant(3.f);
+ test::ExpectTensorEqual<float>(expected, *output);
+}
+
+TEST_F(ReverseOpTest, Reverse_234) {
+ MakeOp(DT_FLOAT);
+
+ // Feed and run
+ // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+ // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]
+ AddInputFromArray<float>(TensorShape({2, 3, 4}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ AddInputFromArray<bool>(TensorShape({3}), {true, false, true});
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor* params_tensor = GetOutput(0);
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3, 4}));
+ // Should become
+ // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]]
+ // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]
+ test::FillValues<float>(
+ &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7,
+ 6, 5, 4, 11, 10, 9, 8});
+ test::ExpectTensorEqual<float>(expected, *params_tensor);
+}
+
+TEST_F(ReverseOpTest, Reverse_1234) {
+ MakeOp(DT_FLOAT);
+
+ // Feed and run
+ // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+ // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]]
+ AddInputFromArray<float>(TensorShape({1, 2, 3, 4}),
+ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ AddInputFromArray<bool>(TensorShape({4}), {true, true, false, true});
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor* params_tensor = GetOutput(0);
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 2, 3, 4}));
+ // Should become
+ // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]]
+ // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]]
+ test::FillValues<float>(
+ &expected, {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 3, 2, 1, 0, 7,
+ 6, 5, 4, 11, 10, 9, 8});
+ test::ExpectTensorEqual<float>(expected, *params_tensor);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
new file mode 100644
index 0000000000..6673a700ef
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -0,0 +1,170 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/reverse_sequence_op.h"
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device>
+void CheckErrors(OpKernelContext* context, int seq_dim) {
+ const Tensor& input = context->input(0);
+ const Tensor& seq_lens = context->input(1);
+
+ auto seq_lens_t = seq_lens.vec<int64>();
+
+ std::vector<int64> seq_lens_vec(seq_lens_t.size());
+
+ // Copy seq_len info down for validity checks
+ context->eigen_device<Device>().memcpyDeviceToHost(
+ seq_lens_vec.data(), seq_lens_t.data(),
+ sizeof(int64) * seq_lens_t.size());
+
+ OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
+ OP_REQUIRES(context, seq_dim < input.dims(),
+ errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
+ seq_dim, " vs. ", input.dims(), ")"));
+
+ OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
+ errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
+ "(", seq_lens.NumElements(), " vs. ",
+ input.dim_size(0), ")"));
+
+ for (int d = 0; d < seq_lens_vec.size(); ++d) {
+ OP_REQUIRES(context, seq_lens_vec[d] >= 0,
+ errors::InvalidArgument("seq_lens(", d, ") < 0"));
+ OP_REQUIRES(context, seq_lens_vec[d] <= input.dim_size(seq_dim),
+ errors::InvalidArgument("seq_lens(", d, ") > input.dims(",
+ seq_dim, ")"));
+ }
+}
+
+template <>
+void CheckErrors<GPUDevice>(OpKernelContext* context, int seq_dim) {
+ const Tensor& input = context->input(0);
+ const Tensor& seq_lens = context->input(1);
+
+ OP_REQUIRES(context, 0 != seq_dim, errors::InvalidArgument("0 == seq_dim"));
+ OP_REQUIRES(context, seq_dim < input.dims(),
+ errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
+ seq_dim, " vs. ", input.dims(), ")"));
+
+ OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(0),
+ errors::InvalidArgument("len(seq_lens) != input.dims(", 0, "), ",
+ "(", seq_lens.NumElements(), " vs. ",
+ input.dim_size(0), ")"));
+}
+
+template <typename Device, typename T>
+class ReverseSequenceOp : public OpKernel {
+ public:
+ explicit ReverseSequenceOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& seq_lens = context->input(1);
+
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens.shape()),
+ errors::InvalidArgument("seq_lens input must be 1-dim, not ",
+ seq_lens.dims()));
+
+ auto seq_lens_t = seq_lens.vec<int64>();
+
+ CheckErrors<Device>(context, seq_dim_);
+
+ const int input_dims = input.dims();
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+#define HANDLE_DIM(NDIM) \
+ case NDIM: \
+ functor::ReverseSequence<Device, T, NDIM>::Compute( \
+ context->eigen_device<Device>(), input.tensor<T, NDIM>(), seq_dim_, \
+ seq_lens_t, output->tensor<T, NDIM>()); \
+ break;
+
+ switch (input_dims) {
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+
+ default:
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument(
+ "ReverseSequenceOp : Unhandled input dimensions: ",
+ input_dims));
+ }
+ }
+
+ private:
+ int32 seq_dim_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ReverseSequenceOp);
+};
+
+#define REGISTER_REVERSE_SEQUENCE(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReverseSequence").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReverseSequenceOp<CPUDevice, type>);
+
+TF_CALL_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE);
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, Dims) \
+ template <> \
+ void ReverseSequence<GPUDevice, T, Dims>::Compute( \
+ const GPUDevice& d, typename TTypes<T, Dims>::ConstTensor input, \
+ int32 seq_dim, TTypes<int64>::ConstVec seq_lens, \
+ typename TTypes<T, Dims>::Tensor output); \
+ extern template struct ReverseSequence<GPUDevice, T, Dims>;
+
+#define DECLARE_GPU_SPECS(T) \
+ DECLARE_GPU_SPEC(T, 2); \
+ DECLARE_GPU_SPEC(T, 3); \
+ DECLARE_GPU_SPEC(T, 4); \
+ DECLARE_GPU_SPEC(T, 5);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_REVERSE_SEQUENCE_GPU(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReverseSequence").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReverseSequenceOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_REVERSE_SEQUENCE_GPU);
+
+#undef REGISTER_REVERSE_SEQUENCE_GPU
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h
new file mode 100644
index 0000000000..d1dd572dcb
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_sequence_op.h
@@ -0,0 +1,56 @@
+#ifndef TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
+#define TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
+// Generator definition for ReverseSequenceOp, must be compilable by nvcc.
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace generator {
+
+template <typename T, size_t Dims>
+class ReverseGenerator {
+ public:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+ ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 seq_dim,
+ TTypes<int64>::ConstVec seq_lengths)
+ : input_(input), seq_dim_(seq_dim), seq_lengths_(seq_lengths) {}
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ operator()(const Eigen::array<Eigen::DenseIndex, Dims>& coords) const {
+ Eigen::array<Eigen::DenseIndex, Dims> new_coords = coords;
+ if (coords[seq_dim_] < seq_lengths_(coords[0])) {
+ new_coords[seq_dim_] = seq_lengths_(coords[0]) - coords[seq_dim_] - 1;
+ }
+
+ return input_(new_coords);
+ }
+
+ private:
+ typename TTypes<T, Dims>::ConstTensor input_;
+ int32 seq_dim_;
+ TTypes<int64>::ConstVec seq_lengths_;
+};
+
+} // namespace generator
+
+namespace functor {
+
+template <typename Device, typename T, size_t Dims>
+struct ReverseSequence {
+ EIGEN_ALWAYS_INLINE static void Compute(
+ const Device& d, typename TTypes<T, Dims>::ConstTensor input,
+ int32 seq_dim, TTypes<int64>::ConstVec seq_lengths,
+ typename TTypes<T, Dims>::Tensor output) {
+ generator::ReverseGenerator<T, Dims> generator(input, seq_dim, seq_lengths);
+ output.device(d) = input.generate(generator);
+ }
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
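ReverseGenerator above only remaps the coordinate along seq_dim, and only within the first seq_lengths(batch) entries of each batch element. A scalar simulation of that transform, plus a worked 2-D example with hypothetical values:

    #include <vector>

    // Mirrors the branch in ReverseGenerator::operator() for seq_dim == 1.
    int ReversedTimeIndex(int batch, int t,
                          const std::vector<int>& seq_lengths) {
      return t < seq_lengths[batch] ? seq_lengths[batch] - t - 1 : t;
    }

    // For a batch x time input [[a0, a1, a2, a3], [b0, b1, b2, b3]] with
    // seq_lengths = {2, 3}, the output is:
    //   [[a1, a0, a2, a3],   // first 2 entries of row 0 reversed
    //    [b2, b1, b0, b3]]   // first 3 entries of row 1 reversed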
diff --git a/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc
new file mode 100644
index 0000000000..7b5d533026
--- /dev/null
+++ b/tensorflow/core/kernels/reverse_sequence_op_gpu.cu.cc
@@ -0,0 +1,26 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/reverse_sequence_op.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_SPEC(T, dims) \
+ template class generator::ReverseGenerator<T, dims>; \
+ template struct functor::ReverseSequence<GPUDevice, T, dims>;
+
+#define DEFINE_GPU_SPECS(T) \
+ DEFINE_GPU_SPEC(T, 2); \
+ DEFINE_GPU_SPEC(T, 3); \
+ DEFINE_GPU_SPEC(T, 4); \
+ DEFINE_GPU_SPEC(T, 5);
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc
new file mode 100644
index 0000000000..71a15c643e
--- /dev/null
+++ b/tensorflow/core/kernels/save_op.cc
@@ -0,0 +1,81 @@
+// See docs in ../ops/io_ops.cc
+#include "tensorflow/core/kernels/io.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+namespace tensorflow {
+
+class SaveOp : public OpKernel {
+ public:
+ explicit SaveOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, false);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp);
+
+class SaveSlicesOp : public OpKernel {
+ public:
+ explicit SaveSlicesOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ SaveTensors(context, &checkpoint::CreateTableTensorSliceBuilder, true);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("SaveSlices").Device(DEVICE_CPU), SaveSlicesOp);
+
+class ShardedFilenameOp : public OpKernel {
+ public:
+ explicit ShardedFilenameOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ static const char* input_names[3] = {"basename", "shard", "num_shards"};
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+ errors::InvalidArgument(
+ input_names[i], " must be a scalar, got shape ",
+ ctx->input(i).shape().ShortDebugString()));
+ }
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+ out->scalar<string>()() = strings::Printf(
+ "%s-%05d-of-%05d", ctx->input(0).scalar<string>()().c_str(),
+ ctx->input(1).scalar<int32>()(), ctx->input(2).scalar<int32>()());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ShardedFilename").Device(DEVICE_CPU),
+ ShardedFilenameOp);
+
+class ShardedFilespecOp : public OpKernel {
+ public:
+ explicit ShardedFilespecOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ static const char* input_names[2] = {"basename", "num_shards"};
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+ errors::InvalidArgument(
+ input_names[i], " must be a scalar, got shape ",
+ ctx->input(i).shape().ShortDebugString()));
+ }
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+ out->scalar<string>()() = strings::Printf(
+ "%s-\?\?\?\?\?-of-%05d", ctx->input(0).scalar<string>()().c_str(),
+ ctx->input(1).scalar<int32>()());
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("ShardedFilespec").Device(DEVICE_CPU),
+ ShardedFilespecOp);
+
+} // namespace tensorflow
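ShardedFilename and ShardedFilespec above differ only in whether the shard number is filled in. A plain-C++ illustration of the two formats (values hypothetical; the '?' characters are escaped as in the kernel to avoid trigraphs):

    #include <cstdio>

    void ShardedNameExample() {
      char buf[64];
      std::snprintf(buf, sizeof(buf), "%s-%05d-of-%05d", "ckpt", 3, 20);
      // buf == "ckpt-00003-of-00020"  (ShardedFilename)
      std::snprintf(buf, sizeof(buf), "%s-\?\?\?\?\?-of-%05d", "ckpt", 20);
      // buf == "ckpt-?????-of-00020"  (ShardedFilespec)
    }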
diff --git a/tensorflow/core/kernels/save_op_test.cc b/tensorflow/core/kernels/save_op_test.cc
new file mode 100644
index 0000000000..ee1ba492a6
--- /dev/null
+++ b/tensorflow/core/kernels/save_op_test.cc
@@ -0,0 +1,443 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+class SaveOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "Save")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput(
+ {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SaveOpTest, Simple) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple");
+ const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
+ "tensor_qint8", "tensor_qint32"};
+
+ MakeOp();
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+
+ // Add the tensor names
+ AddInput<string>(TensorShape({5}),
+ [&tensornames](int x) -> string { return tensornames[x]; });
+
+ // Add a 1-d integer tensor
+ AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
+
+ // Add a 2-d float tensor
+ AddInput<float>(TensorShape({2, 4}),
+ [](int x) -> float { return static_cast<float>(x) / 10; });
+
+ // Add a 2-d double tensor
+ AddInput<double>(TensorShape({2, 4}),
+ [](int x) -> double { return static_cast<double>(x) / 20; });
+
+ // Add a 2-d qint8 tensor
+ AddInput<qint8>(TensorShape({3, 2}),
+ [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
+
+ // Add a 2-d qint32 tensor
+ AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 {
+ return *reinterpret_cast<qint32*>(&x) * qint8(2);
+ });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that the checkpoint file is properly written
+ checkpoint::TensorSliceReader reader(filename,
+ checkpoint::OpenTableTensorSliceReader);
+ EXPECT_OK(reader.status());
+
+ // We expect to find all saved tensors
+ {
+ // The 1-d integer tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
+ TensorShape expected({10});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_INT32, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-");
+ int data[10];
+ std::fill_n(data, 10, 0);
+ EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(i + 1, data[i]);
+ }
+ }
+
+ {
+ // The 2-d float tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_FLOAT, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ float data[8];
+ std::fill_n(data, 8, 0);
+ EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data));
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(static_cast<float>(i) / 10, data[i]);
+ }
+ }
+
+ {
+ // The 2-d double tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_DOUBLE, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ double data[8];
+ std::fill_n(data, 8, 0);
+ EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data));
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(static_cast<double>(i) / 20, data[i]);
+ }
+ }
+
+ {
+ // The 2-d qint8 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
+ TensorShape expected({3, 2});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT8, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ qint8 data[6];
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(*reinterpret_cast<qint8*>(&i), data[i]);
+ }
+ }
+
+ {
+ // The 2-d qint32 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
+ TensorShape expected({2, 3});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT32, type);
+
+ // We expect the tensor value to be correct.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ qint32 data[6];
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), data[i]);
+ }
+ }
+}
+
+class SaveSlicesOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput(
+ {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32}))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+// Here we save only slices. We restore them into a larger tensor and check
+// that the right slice is restored. It is quite tricky to verify that the
+// right slices were actually restored, so instead we just check that
+// CopySliceData() returns true/false depending on the slice we ask for.
+TEST_F(SaveSlicesOpTest, Slices) {
+ const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices");
+ const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double",
+ "tensor_qint8", "tensor_qint32"};
+ // Specifies that the data we save are slices of larger tensors.
+ // See core/framework/tensor_slice.h for the slice syntax.
+ const string tensorshapes[] = {
+ "10 -", // Full contents of a 10 element vector.
+ "2 4 -:0,2", // A 2x2 slice of a 2x4 tensor.
+ "2 4 0,1:2,2", // A 1x2 slice of a 2x4 tensor.
+ "3 2 -:-", // Full contents of a 3x2 tensor.
+ "2 3 1,1:2,1" // Another 1x1 slice of a2x3 tensor.
+ };
+
+ MakeOp();
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+
+ // Add the tensor names
+ AddInput<string>(TensorShape({5}),
+ [&tensornames](int x) -> string { return tensornames[x]; });
+
+ // Add the tensor shapes and slices
+ AddInput<string>(TensorShape({5}), [&tensorshapes](int x) -> string {
+ return tensorshapes[x];
+ });
+
+ // Add a 1-d integer tensor
+ AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; });
+
+ // Add a 2-d float tensor
+ AddInput<float>(TensorShape({2, 2}),
+ [](int x) -> float { return static_cast<float>(x) / 10; });
+
+ // Add a 2-d double tensor
+ AddInput<double>(TensorShape({1, 2}),
+ [](int x) -> double { return static_cast<double>(x) / 20; });
+
+ // Add a 2-d qint8 tensor
+ AddInput<qint8>(TensorShape({3, 2}),
+ [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); });
+
+ // Add a 2-d qint32 tensor
+ AddInput<qint32>(TensorShape({1, 1}), [](int x) -> qint32 {
+ return *reinterpret_cast<qint32*>(&x) * qint8(2);
+ });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that the checkpoint file is properly written
+ checkpoint::TensorSliceReader reader(filename,
+ checkpoint::OpenTableTensorSliceReader);
+ EXPECT_OK(reader.status());
+
+ // We expect to find all saved tensors
+ {
+ // The 1-d integer tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type));
+ TensorShape expected({10});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_INT32, type);
+
+ // We saved the full tensor so we should be able to read it all.
+ TensorSlice s = TensorSlice::ParseOrDie("-");
+ int data[10];
+ EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data));
+ }
+
+ {
+ // The 2-d float tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_FLOAT, type);
+
+ // We saved the slice "-:0,2" so we should not be able to read the full
+ // tensor.
+ TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
+ TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2");
+ float data[8];
+ EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data));
+ EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data));
+ }
+
+ {
+ // The 2-d double tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type));
+ TensorShape expected({2, 4});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_DOUBLE, type);
+
+ // We saved the slice "0,1:2,2" so we should not be able to read the full
+ // tensor.
+ TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
+ TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2");
+ double data[8];
+ EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data));
+ EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data));
+ }
+
+ {
+ // The 2-d qint8 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type));
+ TensorShape expected({3, 2});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT8, type);
+
+ // We saved the full slice.
+ TensorSlice s = TensorSlice::ParseOrDie("-:-");
+ qint8 data[6];
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data));
+ }
+
+ {
+ // The 2-d qint32 tensor
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type));
+ TensorShape expected({2, 3});
+ EXPECT_TRUE(shape.IsSameSize(expected));
+ EXPECT_EQ(DT_QINT32, type);
+
+ // We saved the slice "1,1:2,1" so we should not be able to read the full
+ // tensor.
+ TensorSlice full_slice = TensorSlice::ParseOrDie("-:-");
+ TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1");
+ qint32 data[6];
+ EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data));
+ EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data));
+ }
+}
+
+class SaveOpSlices2Test : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "SaveSlices")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT}))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SaveOpSlices2Test, TwoSlices) {
+ const string filename = io::JoinPath(testing::TmpDir(), "three_slices");
+ // We will save 2 slices of the tensor named "four_by_sixteen", which is
+ // 4x16, and the full contents of the "small" tensor.
+ const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"};
+ const string tensorshapes[] = {
+ // Slice specifications for the 2 slices of "four_by_sixteen"
+ "4 16 0,2:-", // 1st slice covers indices 0 and 1 in the first dim.
+ "4 16 2,2:-", // 2nd slice covers indices 2 and 3 in the first dim.
+ "" // We save the full "small" tensors.
+ };
+
+ MakeOp();
+ // Add a file name
+ AddInput<string>(TensorShape({}),
+ [&filename](int x) -> string { return filename; });
+
+ // Add the tensor names
+ AddInput<string>(TensorShape({3}),
+ [&tensornames](int x) -> string { return tensornames[x]; });
+
+ // Add the tensor shapes and slices
+ AddInput<string>(TensorShape({3}), [&tensorshapes](int x) -> string {
+ return tensorshapes[x];
+ });
+
+ // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16.
+ AddInput<int32>(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; });
+
+ // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16.
+ AddInput<int32>(TensorShape({2, 16}),
+ [](int x) -> int32 { return 10 * (x + 1); });
+
+ // Add a float tensor for "small"
+ AddInput<float>(TensorShape({2, 4}),
+ [](int x) -> float { return static_cast<float>(x) / 10; });
+
+ ASSERT_OK(RunOpKernel());
+
+ // Check that the checkpoint file is properly written
+ checkpoint::TensorSliceReader reader(filename,
+ checkpoint::OpenTableTensorSliceReader);
+ EXPECT_OK(reader.status());
+
+ {
+ // Reload the two slices of "four_by_sixteen" into that tensor.
+ Tensor reloaded(DT_INT32, {4, 16});
+
+ // We expect to find all slices
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type));
+ EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
+ EXPECT_EQ(type, reloaded.dtype());
+
+ // Reload the whole tensor.
+ EXPECT_TRUE(reader.CopySliceData("four_by_sixteen",
+ TensorSlice(reloaded.dims()),
+ reloaded.flat<int>().data()));
+
+ {
+ auto slice = reloaded.Slice(0, 2).flat<int>();
+ for (int i = 0; i < slice.size(); ++i) {
+ EXPECT_EQ(i + 1, slice(i));
+ }
+ }
+ {
+ auto slice = reloaded.Slice(2, 4).flat<int>();
+ for (int i = 0; i < slice.size(); ++i) {
+ EXPECT_EQ(10 * (i + 1), slice(i));
+ }
+ }
+ }
+
+ {
+ // Reload the small float tensor.
+ Tensor reloaded(DT_FLOAT, {2, 4});
+
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("small", &shape, &type));
+ EXPECT_TRUE(shape.IsSameSize(reloaded.shape()));
+ EXPECT_EQ(DT_FLOAT, reloaded.dtype());
+
+ EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()),
+ reloaded.flat<float>().data()));
+
+ for (int64 i = 0; i < reloaded.NumElements(); ++i) {
+ EXPECT_EQ(static_cast<float>(i) / 10, reloaded.flat<float>().data()[i]);
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
new file mode 100644
index 0000000000..88fcc1bdcc
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -0,0 +1,167 @@
+// See docs in ../ops/state_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+enum class UpdateOp { ASSIGN, ADD, SUB };
+
+template <class T, typename Index, UpdateOp op>
+class ScatterUpdateOp : public OpKernel {
+ public:
+ // QUESTION: It'd be nice to support DT_INT16, DT_UINT8,
+ // etc. here. Should we have the framework do some sort of
+ // integer promotion automatically, or should that be something
+ // that users have to do explicitly with a conversion operator
+ // in the graph?
+ explicit ScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ if (use_exclusive_lock_) {
+ // Hold mutex while we apply updates
+ mutex_lock l(*c->input_ref_mutex(0));
+ DoCompute(c);
+ } else {
+ DoCompute(c);
+ }
+ }
+
+ private:
+ bool use_exclusive_lock_;
+
+ // Check whether updates.shape = indices.shape + params.shape[1:]
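+ // e.g. params of shape {5, 3} with indices of shape {2} require updates of
+ // shape {2, 3}; with indices of shape {2, 3} they require updates of shape
+ // {2, 3, 3}.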
+ static bool ValidShapes(const Tensor& params, const Tensor& updates,
+ const Tensor& indices) {
+ if (updates.dims() != indices.dims() + params.dims() - 1) return false;
+ for (int d = 0; d < indices.dims(); d++) {
+ if (updates.dim_size(d) != indices.dim_size(d)) {
+ return false;
+ }
+ }
+ for (int d = 1; d < params.dims(); d++) {
+ if (params.dim_size(d) != updates.dim_size(d - 1 + indices.dims())) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ void DoCompute(OpKernelContext* c) {
+ Tensor Tparams = c->mutable_input(0, use_exclusive_lock_);
+ OP_REQUIRES(c, Tparams.IsInitialized(),
+ errors::FailedPrecondition("Null ref for params"));
+ const Tensor& Tindices = c->input(1);
+ const Tensor& Tupdates = c->input(2);
+ OP_REQUIRES(
+ c, TensorShapeUtils::IsVectorOrHigher(Tparams.shape()),
+ errors::InvalidArgument("params must be at least 1-D, got shape ",
+ Tparams.shape().ShortDebugString()));
+ OP_REQUIRES(
+ c, ValidShapes(Tparams, Tupdates, Tindices),
+ errors::InvalidArgument(
+ "Must have updates.shape = indices.shape + params.shape[1:], got ",
+ "updates.shape ", Tupdates.shape().ShortDebugString(),
+ ", indices.shape ", Tindices.shape().ShortDebugString(),
+ ", params.shape ", Tparams.shape().ShortDebugString()));
+ const Index N = Tindices.NumElements();
+
+ // We always return the input ref.
+ c->forward_ref_input_to_ref_output(0, 0);
+
+ if (N > 0) {
+ const Index first_dim_size = Tparams.dim_size(0);
+ // Validate all the indices are in range
+ auto Tindices_vec = Tindices.flat<Index>();
+ for (Index i = 0; i < N; i++) {
+ const Index index = Tindices_vec(i);
+ OP_REQUIRES(c, index >= 0 && index < first_dim_size,
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ }
+ auto Tparams_flat = Tparams.flat_outer_dims<T>();
+ auto Tupdates_flat =
+ Tupdates.shaped<T, 2>({N, Tupdates.NumElements() / N});
+ for (Index i = 0; i < N; i++) {
+ // Copy last Ndim-1 dimensions of Tupdates[i] to
+ // Tparams[Tindices[i]]
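+ // e.g. with indices {0, 4, 2}, ASSIGN writes updates row 0 into params
+ // row 0, row 1 into params row 4, and row 2 into params row 2; ADD and SUB
+ // accumulate into those rows instead of overwriting them.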
+ switch (op) {
+ case UpdateOp::ASSIGN: {
+ Tparams_flat.template chip<0>(Tindices_vec(i)) =
+ Tupdates_flat.template chip<0>(i);
+ break;
+ }
+ case UpdateOp::ADD: {
+ Tparams_flat.template chip<0>(Tindices_vec(i)) +=
+ Tupdates_flat.template chip<0>(i);
+ break;
+ }
+ case UpdateOp::SUB: {
+ Tparams_flat.template chip<0>(Tindices_vec(i)) -=
+ Tupdates_flat.template chip<0>(i);
+ break;
+ }
+ }
+ }
+ }
+ }
+};
+
+#define REGISTER_SCATTER_UPDATE(type, index_type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ScatterUpdate") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterUpdateOp<type, index_type, UpdateOp::ASSIGN>);
+
+#define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32)
+#define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT32);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT64);
+
+#undef REGISTER_SCATTER_UPDATE_INT64
+#undef REGISTER_SCATTER_UPDATE_INT32
+#undef REGISTER_SCATTER_UPDATE
+
+#define REGISTER_SCATTER_ADD(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ScatterAdd") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterUpdateOp<type, index_type, UpdateOp::ADD>);
+
+#define REGISTER_SCATTER_ADD_INT32(type) REGISTER_SCATTER_ADD(type, int32)
+#define REGISTER_SCATTER_ADD_INT64(type) REGISTER_SCATTER_ADD(type, int64)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT32);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_INT64);
+
+#undef REGISTER_SCATTER_ADD_INT32
+#undef REGISTER_SCATTER_ADD_INT64
+#undef REGISTER_SCATTER_ADD
+
+#define REGISTER_SCATTER_SUB(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ScatterSub") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ScatterUpdateOp<type, index_type, UpdateOp::SUB>);
+
+#define REGISTER_SCATTER_SUB_INT32(type) REGISTER_SCATTER_SUB(type, int32)
+#define REGISTER_SCATTER_SUB_INT64(type) REGISTER_SCATTER_SUB(type, int64)
+
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT32);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_SUB_INT64);
+
+#undef REGISTER_SCATTER_SUB_INT64
+#undef REGISTER_SCATTER_SUB_INT32
+#undef REGISTER_SCATTER_SUB
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
new file mode 100644
index 0000000000..8885f1edb3
--- /dev/null
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -0,0 +1,255 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+class ScatterUpdateOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType index_type) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate")
+ .Input(FakeInput(DT_FLOAT_REF))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(ScatterUpdateOpTest, Simple_TwoD32) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
+ 10002, 0, 0, 0, 777, 778, 779});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterUpdateOpTest, Simple_TwoD64) {
+ MakeOp(DT_INT64);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int64>(TensorShape({3}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
+ test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
+ 10002, 0, 0, 0, 777, 778, 779});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterUpdateOpTest, Simple_ZeroD) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {101});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterUpdateOpTest, Simple_OneD) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
+ test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterUpdateOpTest, HigherRank) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({2, 3}), {0, 4, 2, 1, 3, 6});
+ AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
+ test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
+ test::ExpectTensorEqual<float>(expected, params_tensor);
+}
+
+TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 99});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Index 99 at offset 2 in indices is out of range"))
+ << s;
+}
+
+TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({1, 3}), {0, 4, 99});
+ AddInputFromArray<float>(TensorShape({3, 3}),
+ {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Must have updates.shape = indices.shape + "
+ "params.shape[1:], got "))
+ << s;
+}
+
+TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
+ AddInputFromArray<float>(
+ TensorShape({3, 4}),
+ {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Must have updates.shape = indices.shape + "
+ "params.shape[1:], got "))
+ << s;
+}
+
+TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
+ MakeOp(DT_INT32);
+
+ // Feed and run
+ AddInputFromArray<float>(TensorShape({5, 3}),
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
+ AddInputFromArray<float>(TensorShape({2, 3}),
+ {100, 101, 102, 10000, 10001, 10002});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString())
+ .contains("Must have updates.shape = indices.shape + "
+ "params.shape[1:], got "))
+ << s;
+}
+
+class ScatterUpdateBM : public ScatterUpdateOpTest {
+ public:
+ virtual void TestBody() {}
+ void MakeBenchmarkOp(const char* op, DataType index_type) {
+ ASSERT_OK(NodeDefBuilder("myop", op)
+ .Input(FakeInput(DT_FLOAT_REF))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_CHECK_OK(InitOp());
+ }
+};
+
+template <typename Index>
+static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
+ testing::StopTiming();
+ const int kRows = 10000000 / embedding_size;
+ std::vector<float> values;
+ for (int i = 0; i < kRows * embedding_size; i++) {
+ values.push_back(i);
+ }
+ const int kNumUpdates = 1000;
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ std::vector<Index> indices;
+ std::vector<float> updates;
+ for (int i = 0; i < kNumUpdates; i++) {
+ indices.push_back(rnd.Uniform(kRows));
+ for (int j = 0; j < embedding_size; j++) {
+ updates.push_back(i * 10 + j);
+ }
+ }
+
+ ScatterUpdateBM bm;
+ bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
+ bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
+ bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
+ bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
+ updates);
+ testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
+ iters);
+ testing::StartTiming();
+ while (iters-- > 0) {
+ Status s = bm.RunOpKernel();
+ }
+}
+
+static void BM_ScatterUpdateInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterUpdate");
+}
+static void BM_ScatterUpdateInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterUpdate");
+}
+
+static void BM_ScatterAddInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd");
+}
+static void BM_ScatterAddInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
+}
+
+BENCHMARK(BM_ScatterUpdateInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterUpdateInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
new file mode 100644
index 0000000000..2b6a8c5a88
--- /dev/null
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -0,0 +1,466 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// This operator handles reducing segments along the first dimension.
+// See core/ops/math_ops.cc for more details.
+template <typename Device, class T, class Index, typename Reducer>
+class SegmentReductionOp : public OpKernel {
+ public:
+ explicit SegmentReductionOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& segment_ids = context->input(1);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
+ errors::InvalidArgument("segment_ids should be a vector."));
+ const int64 num_indices = segment_ids.NumElements();
+ OP_REQUIRES(context, num_indices == input.dim_size(0),
+ errors::InvalidArgument(
+ "segment_ids should be the same size as dimension 0 of"
+ " input."));
+
+ auto input_flat = input.flat_outer_dims<T>();
+ const int64 num_col = input_flat.dimension(1);
+
+ const auto segment_vec = segment_ids.vec<Index>();
+ // Note that the current implementation assumes that segment_vec values are
+ // sorted.
+ const Index output_rows =
+ num_indices > 0 ? segment_vec(num_indices - 1) + 1 : 0;
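+ // e.g. sorted segment_ids {0, 0, 1, 2, 2} give output_rows == 3.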
+
+ TensorShape output_shape = input.shape();
+ output_shape.set_dim(0, output_rows);
+
+ // Note that we do not initialize the output buffer with a default value.
+ // We require that segment ids be sorted and cover all values (otherwise we
+ // return an error).
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ auto output_flat = output->flat_outer_dims<T>();
+
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::DSizes<Eigen::DenseIndex, 1> dims_to_reduce;
+ dims_to_reduce[0] = 0;
+#else
+ Eigen::IndexList<Eigen::type2index<0>> dims_to_reduce;
+#endif
+ Index start = 0, end = 1;
+ // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it
+ // across threads.
+ Eigen::DSizes<Eigen::DenseIndex, 1> out_slice_shape(num_col);
+ while (end <= num_indices) {
+ if (end < num_indices) {
+ if (segment_vec(start) == segment_vec(end)) {
+ ++end;
+ continue;
+ }
+ // We have a new segment here. Verify that the segment ids grow by one
+ // each time, so that we cover every possible output value.
+ OP_REQUIRES(
+ context, segment_vec(start) + 1 == segment_vec(end),
+ errors::InvalidArgument("segment ids are not increasing by 1"));
+ }
+
+ // Process segment [start, end)
+ const T* in_slice_ptr = &input_flat(start, 0);
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
+ Eigen::Unaligned> OutT;
+ T* out_slice_ptr = &output_flat(segment_vec(start), 0);
+ OutT out_slice(out_slice_ptr, out_slice_shape);
+ // We don't use out_slice.device(context->eigen_device<Device>())
+ // because these pieces of work are likely to be very small and
+ // the context switching overhead dwarfs any benefit we get from
+ // using another thread to do this work.
+ if (start == end - 1) {
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
+ Eigen::Unaligned> InT;
+ InT in_slice(in_slice_ptr, out_slice_shape);
+ out_slice = in_slice;
+ } else {
+ Eigen::DSizes<Eigen::DenseIndex, 2> in_slice_shape(end - start,
+ num_col);
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned> InT;
+ InT in_slice(in_slice_ptr, in_slice_shape);
+
+ out_slice = in_slice.reduce(dims_to_reduce, Reducer());
+ }
+ start = end;
+ ++end;
+ }
+ }
+};
+
+#define REGISTER_CPU_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SegmentSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionOp<CPUDevice, type, index_type, \
+ Eigen::internal::SumReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SegmentMean") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionOp<CPUDevice, type, index_type, \
+ Eigen::internal::MeanReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SegmentProd") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionOp<CPUDevice, type, index_type, \
+ Eigen::internal::ProdReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SegmentMin") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionOp<CPUDevice, type, index_type, \
+ Eigen::internal::MinReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SegmentMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SegmentReductionOp<CPUDevice, type, index_type, \
+ Eigen::internal::MaxReducer<type>>);
+
+#define REGISTER_CPU_KERNELS_ALL(type) \
+ REGISTER_CPU_KERNELS(type, int32); \
+ REGISTER_CPU_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS_ALL);
+#undef REGISTER_CPU_KERNELS
+#undef REGISTER_CPU_KERNELS_ALL
+
+// Similar to SegmentReductionOp but can handle unsorted segment definitions and
+// specifying size of output.
+template <typename Device, class T, class Index>
+class UnsortedSegmentSumOp : public OpKernel {
+ public:
+ explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& data = context->input(0);
+ const Tensor& segment_ids = context->input(1);
+ const Tensor& num_segments = context->input(2);
+
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsLegacyScalar(num_segments.shape()),
+ errors::InvalidArgument("num_segments should be a scalar, not shape ",
+ num_segments.shape().ShortDebugString()));
+
+ OP_REQUIRES(context,
+ TensorShapeUtils::StartsWith(data.shape(), segment_ids.shape()),
+ errors::InvalidArgument(
+ "data.shape = ", data.shape().ShortDebugString(),
+ " does not start with segment_ids.shape = ",
+ segment_ids.shape().ShortDebugString()));
+
+ const auto segment_flat = segment_ids.flat<Index>();
+ const int32 N = segment_flat.dimension(0);
+ const int32 output_rows = num_segments.scalar<int32>()();
+
+ if (N > 0) {
+ Eigen::Tensor<Index, 0, Eigen::RowMajor> m = segment_flat.maximum();
+ OP_REQUIRES(
+ context, m() < output_rows,
+ errors::InvalidArgument("More segments found than output size"));
+ }
+
+ TensorShape output_shape;
+ output_shape.AddDim(output_rows);
+ for (int i = segment_ids.dims(); i < data.dims(); i++) {
+ output_shape.AddDim(data.dim_size(i));
+ }
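+ // e.g. data of shape {4, 8} with segment_ids of shape {4} and
+ // num_segments == 3 gives an output of shape {3, 8}.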
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ auto output_flat = output->flat_outer_dims<T>();
+ output_flat.setZero();
+
+ if (data.NumElements() > 0) {
+ auto data_flat = data.shaped<T, 2>({N, data.NumElements() / N});
+ for (int i = 0; i < N; ++i) {
+ output_flat.template chip<0>(segment_flat(i)) +=
+ data_flat.template chip<0>(i);
+ }
+ }
+ }
+};
+
+#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ UnsortedSegmentSumOp<CPUDevice, type, index_type>);
+
+#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \
+ REGISTER_CPU_UNSORTED_KERNELS(type, int32); \
+ REGISTER_CPU_UNSORTED_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
+#undef REGISTER_CPU_UNSORTED_KERNELS
+#undef REGISTER_CPU_UNSORTED_KERNELS_ALL
+
+// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented
+// by two dense tensors, one containing the data, and the other containing
+// indices into the data.
+template <typename Device, class T>
+class SparseSegmentReductionOpBase : public OpKernel {
+ public:
+ explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
+ bool is_mean)
+ : OpKernel(context), is_mean_(is_mean) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& indices = context->input(1);
+ const Tensor& segment_ids = context->input(2);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
+ errors::InvalidArgument("indices should be a vector."));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
+ errors::InvalidArgument("segment_ids should be a vector."));
+
+ const int32 num_indices = indices.NumElements();
+ OP_REQUIRES(context, num_indices == segment_ids.NumElements(),
+ errors::InvalidArgument(
+ "segment_ids and indices should have same size."));
+
+ auto input_flat = input.flat_outer_dims<T>();
+
+ const auto indices_vec = indices.vec<int32>();
+ const auto segment_vec = segment_ids.vec<int32>();
+ // Note that the current implementation assumes that segment_vec values are
+ // sorted.
+ const int32 output_rows =
+ num_indices > 0 ? segment_vec(num_indices - 1) + 1 : 0;
+
+ TensorShape output_shape = input.shape();
+ output_shape.set_dim(0, output_rows);
+
+ // Note that we do not initialize the output buffer with a default value.
+ // We require that segment ids be sorted and cover all values (otherwise we
+ // return an error).
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ if (num_indices == 0) return;
+ auto output_flat = output->flat_outer_dims<T>();
+
+ int32 start = 0, end = 1;
+ while (end <= num_indices) {
+ if (end < num_indices) {
+ if (segment_vec(start) == segment_vec(end)) {
+ ++end;
+ continue;
+ }
+ // We have a new segment here. Verify that the segment ids grow by one
+ // each time, so that we cover every possible output value.
+ OP_REQUIRES(
+ context, segment_vec(start) + 1 == segment_vec(end),
+ errors::InvalidArgument("segment ids are not increasing by 1"));
+ }
+
+ auto out = output_flat.template chip<0>(segment_vec(start));
+#define I(i) input_flat.template chip<0>(indices_vec(start + i))
+ int num = end - start;
+ if (num == 1) {
+ out = I(0);
+ } else {
+ int r = num % 8;
+ T m = (is_mean_ && (num < 10)) ? num : 1;
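+ // The switch below sums the first num % 8 rows (8 or 9 rows when the
+ // remainder is 0 or 1), then the loop after it adds the remaining rows in
+ // groups of 8. For short mean segments (num < 10) the divisor m is applied
+ // in the switch; longer mean segments are divided by num at the end.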
+ switch (r) {
+ case 2:
+ out = (I(0) + I(1)) / m;
+ break;
+ case 3:
+ out = (I(0) + I(1) + I(2)) / m;
+ break;
+ case 4:
+ out = (I(0) + I(1) + I(2) + I(3)) / m;
+ break;
+ case 5:
+ out = (I(0) + I(1) + I(2) + I(3) + I(4)) / m;
+ break;
+ case 6:
+ out = (I(0) + I(1) + I(2) + I(3) + I(4) + I(5)) / m;
+ break;
+ case 7:
+ out = (I(0) + I(1) + I(2) + I(3) + I(4) + I(5) + I(6)) / m;
+ break;
+ case 0:
+ out = (I(0) + I(1) + I(2) + I(3) + I(4) + I(5) + I(6) + I(7)) / m;
+ r = 8;
+ break;
+ case 1:
+ out =
+ (I(0) + I(1) + I(2) + I(3) + I(4) + I(5) + I(6) + I(7) + I(8)) /
+ m;
+ r = 9;
+ break;
+ }
+ for (; r < num; r += 8) {
+ out += I(r) + I(r + 1) + I(r + 2) + I(r + 3) + I(r + 4) + I(r + 5) +
+ I(r + 6) + I(r + 7);
+ }
+#undef I
+ if (is_mean_ && num >= 10) {
+ out = out / static_cast<T>(num);
+ }
+ }
+ start = end;
+ ++end;
+ }
+ }
+
+ private:
+ bool is_mean_;
+};
+
+template <typename Device, class T>
+class SparseSegmentReductionMeanOp
+ : public SparseSegmentReductionOpBase<Device, T> {
+ public:
+ explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
+ : SparseSegmentReductionOpBase<Device, T>(context, true /*is_mean*/) {}
+};
+
+template <typename Device, class T>
+class SparseSegmentReductionSumOp
+ : public SparseSegmentReductionOpBase<Device, T> {
+ public:
+ explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
+ : SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/) {}
+};
+
+#define REGISTER_CPU_SPARSE_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSegmentSum").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseSegmentReductionSumOp<CPUDevice, type>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
+#undef REGISTER_CPU_SPARSE_KERNELS
+
+#define REGISTER_CPU_SPARSE_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSegmentMean").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseSegmentReductionMeanOp<CPUDevice, type>);
+REGISTER_CPU_SPARSE_KERNELS(float);
+REGISTER_CPU_SPARSE_KERNELS(double);
+#undef REGISTER_CPU_SPARSE_KERNELS
+
+template <class T>
+class SparseSegmentMeanGradOp : public OpKernel {
+ public:
+ explicit SparseSegmentMeanGradOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& indices = context->input(1);
+ const Tensor& segment_ids = context->input(2);
+ const Tensor& output_dim0 = context->input(3);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
+ errors::InvalidArgument("indices should be a vector."));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
+ errors::InvalidArgument("segment_ids should be a vector."));
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(output_dim0.shape()),
+ errors::InvalidArgument("output_dim0 should be a scalar."));
+
+ const int64 N = indices.NumElements();
+ OP_REQUIRES(context, N == segment_ids.NumElements(),
+ errors::InvalidArgument(
+ "segment_ids and indices should have same size."));
+ const int32 M = output_dim0.scalar<int32>()();
+
+ auto input_flat = input.flat_outer_dims<T>();
+ const auto indices_vec = indices.vec<int32>();
+ const auto segment_vec = segment_ids.vec<int32>();
+
+ TensorShape output_shape = input.shape();
+ output_shape.set_dim(0, M);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ if (M == 0 || N == 0) return;
+
+ // Note that similar to SparseSegmentMean, we assume that segment_vec is
+ // already sorted and has non-negative values.
+ int num_segments = segment_vec(N - 1) + 1;
+ OP_REQUIRES(context, input.dim_size(0) == num_segments,
+ errors::InvalidArgument("Invalid number of segments"));
+
+ // Compute scaling factors for input.
+ std::vector<double> scaling(num_segments, 0.0);
+ for (int64 i = 0; i < N; ++i) {
+ scaling[segment_vec(i)] += 1;
+ }
+ for (int i = 0; i < scaling.size(); ++i) {
+ scaling[i] = 1.0 / std::max(scaling[i], 1.0);
+ }
+
+ auto output_flat = output->flat_outer_dims<T>();
+ output_flat.setZero();
+ std::vector<bool> is_modified(M, false);
+
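+ // For each position i, output row indices_vec(i) receives input row
+ // segment_vec(i) scaled by 1 / (number of indices in that segment); rows
+ // hit more than once accumulate. This mirrors the averaging done by the
+ // forward SparseSegmentMean.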
+ for (int64 i = 0; i < N; ++i) {
+ int output_idx = indices_vec(i);
+ int idx = segment_vec(i);
+ T scale = static_cast<T>(scaling[idx]);
+ if (is_modified[output_idx]) {
+ if (scale == 1.0) {
+ output_flat.template chip<0>(output_idx) +=
+ input_flat.template chip<0>(idx);
+ } else {
+ output_flat.template chip<0>(output_idx) +=
+ input_flat.template chip<0>(idx) * scale;
+ }
+ } else {
+ if (scale == 1.0) {
+ output_flat.template chip<0>(output_idx) =
+ input_flat.template chip<0>(idx);
+ } else {
+ output_flat.template chip<0>(output_idx) =
+ input_flat.template chip<0>(idx) * scale;
+ }
+ }
+ is_modified[output_idx] = true;
+ }
+ }
+};
+
+#define REGISTER_CPU_SPARSE_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("SparseSegmentMeanGrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T"), \
+ SparseSegmentMeanGradOp<type>);
+
+REGISTER_CPU_SPARSE_KERNELS(float);
+REGISTER_CPU_SPARSE_KERNELS(double);
+
+#undef REGISTER_CPU_SPARSE_KERNELS
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc
new file mode 100644
index 0000000000..87647a21a8
--- /dev/null
+++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc
@@ -0,0 +1,157 @@
+#include <functional>
+
+#include "tensorflow/core/public/session_options.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+
+namespace tensorflow {
+
+template <typename Index>
+static void BM_SegmentReduction(int iters, string reduction, Index num_rows,
+ Index num_cols, Index segment_size) {
+ testing::StopTiming();
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ // Create inputs
+ gtl::InlinedVector<TensorValue, 4> reduction_inputs;
+ TensorShape shape1({num_rows, num_cols});
+ Tensor input1(DT_FLOAT, shape1);
+ reduction_inputs.push_back({nullptr, &input1});
+
+ TensorShape shape2({num_rows});
+ Tensor input2(DataTypeToEnum<Index>::v(), shape2);
+ test::FillFn<Index>(&input2, [&num_rows, &segment_size](Index i) -> Index {
+ return std::min(i / segment_size, num_rows - 1);
+ });
+ reduction_inputs.push_back({nullptr, &input2});
+
+ NodeDef reduction_node_def;
+ TF_CHECK_OK(NodeDefBuilder(reduction, reduction)
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DataTypeToEnum<Index>::v()))
+ .Finalize(&reduction_node_def));
+ Status status;
+ std::unique_ptr<OpKernel> reduction_op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), reduction_node_def, &status));
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &reduction_inputs;
+ params.op_kernel = reduction_op.get();
+ params.output_alloc_attr = [&device, &reduction_op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host =
+ (reduction_op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> reduction_context(
+ new OpKernelContext(params));
+
+ reduction_op->Compute(reduction_context.get());
+ TF_CHECK_OK(reduction_context->status());
+ testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete reduction_context->release_output(0).tensor;
+ reduction_op->Compute(reduction_context.get());
+ }
+ int64 bytes_per_iter =
+ static_cast<int64>(num_rows * num_cols * sizeof(float));
+ testing::BytesProcessed(bytes_per_iter * iters);
+}
+
+#define BM_Reduce(O, R, C, S) \
+ static void BM_Reduce_##O##_##R##_##C##_##S##_int32(int iters) { \
+ BM_SegmentReduction<int32>(iters, #O, R, C, S); \
+ } \
+ static void BM_Reduce_##O##_##R##_##C##_##S##_int64(int iters) { \
+ BM_SegmentReduction<int64>(iters, #O, R, C, S); \
+ } \
+ BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int32); \
+ BENCHMARK(BM_Reduce_##O##_##R##_##C##_##S##_int64);
+
+#define BM_Reduce_Arg(R, C, S) \
+ BM_Reduce(SegmentSum, R, C, S); \
+ BM_Reduce(SegmentMean, R, C, S);
+
+BM_Reduce_Arg(64, 32, 1);
+BM_Reduce_Arg(4096, 128, 1);
+
+BM_Reduce_Arg(16, 8, 2);
+BM_Reduce_Arg(64, 32, 2);
+BM_Reduce_Arg(4096, 32, 2);
+BM_Reduce_Arg(4096, 128, 2);
+
+static void SparseSegmentMeanGradHelper(int iters, float uniqueness, int size) {
+ testing::StopTiming();
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+ CHECK_LE(uniqueness, 1.0);
+ CHECK_GT(uniqueness, 0.0);
+
+ const int kNumIndices = size;
+ Tensor indices(DT_INT32, TensorShape({kNumIndices}));
+ auto indices_flat = indices.flat<int32>();
+ Tensor segments(DT_INT32, TensorShape({kNumIndices}));
+ auto segments_flat = segments.flat<int32>();
+
+ int kUniqueIndices = uniqueness * kNumIndices;
+ Tensor output_dim0(DT_INT32, TensorShape({}));
+ output_dim0.scalar<int32>()() = kUniqueIndices;
+
+ for (int i = 0; i < kNumIndices; ++i) {
+ indices_flat(i) = (i * 31) % kUniqueIndices;
+ segments_flat(i) = i * .8;
+ }
+
+ const int kDim1 = segments_flat(kNumIndices - 1) + 1;
+ const int kDim2 = 128;
+ Tensor input(DT_FLOAT, TensorShape({kDim1, kDim2}));
+ input.flat<float>().setRandom();
+
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseSegmentMeanGrad")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, indices))
+ .Input(test::graph::Constant(g, segments))
+ .Input(test::graph::Constant(g, output_dim0))
+ .Attr("T", DT_FLOAT)
+ .Finalize(g, &node));
+
+ testing::UseRealTime();
+ testing::BytesProcessed(static_cast<int64>(iters) * (kDim1 * kDim2) *
+ sizeof(float));
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+static void BM_SparseSegmentMeanGrad_Low(int iters, int size) {
+ return SparseSegmentMeanGradHelper(iters, 1.0, size);
+}
+
+static void BM_SparseSegmentMeanGrad_High(int iters, int size) {
+ return SparseSegmentMeanGradHelper(iters, 0.01, size);
+}
+
+BENCHMARK(BM_SparseSegmentMeanGrad_Low)->Arg(1000)->Arg(100000);
+BENCHMARK(BM_SparseSegmentMeanGrad_High)->Arg(1000)->Arg(100000);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc
new file mode 100644
index 0000000000..2abb183d1a
--- /dev/null
+++ b/tensorflow/core/kernels/sendrecv_ops.cc
@@ -0,0 +1,116 @@
+#include "tensorflow/core/kernels/sendrecv_ops.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
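+// For illustration only (the device names and tensor name below are made up,
+// not produced by this file): the two helpers below build keys of the form
+// send_device;FpToString(send_device_incarnation);recv_device;tensor_name;frame_id:iter_id
+// e.g. "/cpu:0;<incarnation>;/gpu:0;edge_5;0:0".
+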
+static string GetRendezvousKeyPrefix(const string& send_device,
+ const string& recv_device,
+ const uint64 send_device_incarnation,
+ const string& tensor_name) {
+ return strings::StrCat(send_device, ";",
+ strings::FpToString(send_device_incarnation), ";",
+ recv_device, ";", tensor_name);
+}
+
+static string GetRendezvousKey(const string& key_prefix,
+ const FrameAndIter& frame_iter) {
+ return strings::StrCat(key_prefix, ";", frame_iter.frame_id, ":",
+ frame_iter.iter_id);
+}
+
+SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string send_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
+ string recv_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
+ uint64 send_device_incarnation;
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("send_device_incarnation",
+ reinterpret_cast<int64*>(&send_device_incarnation)));
+ string tensor_name;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
+ key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
+ send_device_incarnation, tensor_name);
+}
+
+void SendOp::Compute(OpKernelContext* ctx) {
+ OP_REQUIRES(
+ ctx, ctx->rendezvous() != nullptr,
+ errors::Internal("Op kernel context needs to provide a rendezvous."));
+ const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
+ VLOG(2) << "Send " << key;
+
+ // The device context may be passed across the Send/Recv
+ // boundary, so that the device context used to produce the Tensor
+ // is used when performing the copy on the recv side (which may be
+ // a different device).
+ Rendezvous::Args args;
+ args.device_context = ctx->op_device_context();
+ args.alloc_attrs = ctx->input_alloc_attr(0);
+ Status s =
+ ctx->rendezvous()->Send(key, args, ctx->input(0), ctx->is_input_dead());
+ ctx->SetStatus(s);
+}
+
+REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
+REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);
+
+REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);
+
+RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ string send_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
+ string recv_device;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
+ uint64 send_device_incarnation;
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("send_device_incarnation",
+ reinterpret_cast<int64*>(&send_device_incarnation)));
+ string tensor_name;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
+ key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
+ send_device_incarnation, tensor_name);
+}
+
+void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
+ OP_REQUIRES(
+ ctx, ctx->rendezvous() != nullptr,
+ errors::Internal("Op kernel context needs to provide a rendezvous."));
+ const string key = GetRendezvousKey(key_prefix_, ctx->frame_iter());
+ VLOG(2) << "Recv " << key;
+
+ Rendezvous::Args args;
+ args.device_context = ctx->op_device_context();
+ args.alloc_attrs = ctx->output_alloc_attr(0);
+ ctx->rendezvous()->RecvAsync(
+ key, args, [ctx, done](const Status& s, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& val, bool is_dead) {
+ ctx->SetStatus(s);
+ if (s.ok()) {
+ // 'ctx' allocates the output tensor of the expected type. The
+ // runtime checks whether the tensor received here is the same type.
+ if (!is_dead) {
+ ctx->set_output(0, val);
+ }
+ *ctx->is_output_dead() = is_dead;
+ }
+ done();
+ });
+}
+
+REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
+REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);
+
+REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/sendrecv_ops.h b/tensorflow/core/kernels/sendrecv_ops.h
new file mode 100644
index 0000000000..b3f5703ccf
--- /dev/null
+++ b/tensorflow/core/kernels/sendrecv_ops.h
@@ -0,0 +1,32 @@
+#ifndef TENSORFLOW_KERNELS_SENDRECV_OPS_H_
+#define TENSORFLOW_KERNELS_SENDRECV_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class SendOp : public OpKernel {
+ public:
+ explicit SendOp(OpKernelConstruction* ctx);
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ string key_prefix_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
+};
+
+class RecvOp : public AsyncOpKernel {
+ public:
+ explicit RecvOp(OpKernelConstruction* ctx);
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
+
+ private:
+ string key_prefix_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SENDRECV_OPS_H_
diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc
new file mode 100644
index 0000000000..60ba2e15f9
--- /dev/null
+++ b/tensorflow/core/kernels/sequence_ops.cc
@@ -0,0 +1,123 @@
+// See docs in ../ops/math_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+int32 GetValue(int32 v) { return v; }
+
+template <typename T>
+class RangeOp : public OpKernel {
+ public:
+ explicit RangeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& start_in = context->input(0);
+ const Tensor& limit_in = context->input(1);
+ const Tensor& delta_in = context->input(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(start_in.shape()),
+ errors::InvalidArgument("start must be a scalar, not shape ",
+ start_in.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(limit_in.shape()),
+ errors::InvalidArgument("limit must be a scalar, not shape ",
+ limit_in.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(delta_in.shape()),
+ errors::InvalidArgument("delta must be a scalar, not shape ",
+ delta_in.shape().ShortDebugString()));
+ const int32 start = GetValue(start_in.scalar<T>()());
+ const int32 limit = GetValue(limit_in.scalar<T>()());
+ OP_REQUIRES(context, start <= limit,
+ errors::InvalidArgument("Requires start <= limit: ", start, "/",
+ limit));
+ const int32 delta = GetValue(delta_in.scalar<T>()());
+ OP_REQUIRES(context, delta > 0,
+ errors::InvalidArgument("Requires delta > 0: ", delta));
+ int32 size = (limit - start + delta - 1) / delta;
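+ // size is ceil((limit - start) / delta); e.g. start=3, limit=10, delta=3
+ // gives size=3 and the values 3, 6, 9.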
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({size}), &out));
+ auto flat = out->flat<T>();
+ int32 val = start;
+ for (int32 i = 0; i < size; ++i) {
+ flat(i) = T(val);
+ val += delta;
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Range")
+ .Device(DEVICE_CPU)
+ .HostMemory("start")
+ .HostMemory("limit")
+ .HostMemory("delta")
+ .HostMemory("output"),
+ RangeOp<int32>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("Range")
+ .Device(DEVICE_GPU)
+ .HostMemory("start")
+ .HostMemory("limit")
+ .HostMemory("delta")
+ .HostMemory("output"),
+ RangeOp<int32>);
+#endif // GOOGLE_CUDA
+
+template <typename T>
+class LinSpaceOp : public OpKernel {
+ public:
+ explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& start_in = context->input(0);
+ const Tensor& stop_in = context->input(1);
+ const Tensor& num_in = context->input(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_in.shape()),
+ errors::InvalidArgument("start must be a scalar, not shape ",
+ start_in.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(stop_in.shape()),
+ errors::InvalidArgument("stop must be a scalar, not shape ",
+ stop_in.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_in.shape()),
+ errors::InvalidArgument("num must be a scalar, not shape ",
+ num_in.shape().ShortDebugString()));
+ const T start = start_in.scalar<T>()();
+ const T stop = stop_in.scalar<T>()();
+ const int32 num = num_in.scalar<int32>()();
+ OP_REQUIRES(context, num > 0,
+ errors::InvalidArgument("Requires num > 0: ", num));
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({num}), &out));
+ auto flat = out->flat<T>();
+ if (num == 1) {
+ flat(0) = start;
+ } else {
+ const T step = (stop - start) / (num - 1);
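+ // e.g. start=0, stop=10, num=5 gives step=2.5 and the values
+ // 0, 2.5, 5, 7.5, 10.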
+ for (int32 i = 0; i < num; ++i) flat(i) = start + step * i;
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("LinSpace")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("start")
+ .HostMemory("stop")
+ .HostMemory("num")
+ .HostMemory("output"),
+ LinSpaceOp<float>);
+REGISTER_KERNEL_BUILDER(Name("LinSpace")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("start")
+ .HostMemory("stop")
+ .HostMemory("num")
+ .HostMemory("output"),
+ LinSpaceOp<double>);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
new file mode 100644
index 0000000000..7cb1da8983
--- /dev/null
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -0,0 +1,261 @@
+// See docs in ../ops/array_ops.cc.
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+class ShapeOp : public OpKernel {
+ public:
+ explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& inp = ctx->input(0);
+ const int rank = inp.dims();
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
+ auto vec = out->vec<int32>();
+ for (int i = 0; i < rank; ++i) vec(i) = inp.dim_size(i);
+ }
+
+ bool IsExpensive() override { return false; }
+};
+REGISTER_KERNEL_BUILDER(Name("Shape").Device(DEVICE_CPU).HostMemory("output"),
+ ShapeOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Shape") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ ShapeOp)
+TF_CALL_REAL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Shape")
+ .Device(DEVICE_GPU)
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ ShapeOp);
+
+class RankOp : public OpKernel {
+ public:
+ explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& inp = ctx->input(0);
+ const int rank = inp.dims();
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+ out->scalar<int32>()() = rank;
+ }
+
+ bool IsExpensive() override { return false; }
+};
+REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
+ RankOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Rank") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("output"), \
+ RankOp);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Rank")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .HostMemory("input")
+ .HostMemory("output"),
+ RankOp);
+
+class SizeOp : public OpKernel {
+ public:
+ explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& inp = ctx->input(0);
+ const int64 size = inp.NumElements();
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
+ // TODO(josh11b): switch output to int64?
+ out->scalar<int32>()() = size;
+ }
+
+ bool IsExpensive() override { return false; }
+};
+REGISTER_KERNEL_BUILDER(Name("Size").Device(DEVICE_CPU).HostMemory("output"),
+ SizeOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("output"), \
+ SizeOp);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Size")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .HostMemory("input")
+ .HostMemory("output"),
+ SizeOp);
+
+class ExpandDimsOp : public OpKernel {
+ public:
+ explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ int dim = ctx->input(1).flat<int>()(0);
+ OP_REQUIRES(
+ ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
+ errors::InvalidArgument("Tried to expand dim index ", dim,
+ " for tensor with ", ctx->input(0).dims(),
+ " dimensions."));
+
+ auto existing_dims = ctx->input(0).shape().dim_sizes();
+ std::vector<int64> new_shape(existing_dims.size());
+ for (size_t i = 0; i < new_shape.size(); ++i) {
+ new_shape[i] = existing_dims[i];
+ }
+
+ // We emulate numpy's interpretation of the dim axis when
+ // -1 - input.dims() <= dim <= input.dims().
+ if (dim < 0) {
+ dim += existing_dims.size() + 1;
+ }
+
+ // Clamp to the end if needed.
+ dim = std::min<int32>(dim, existing_dims.size());
+ new_shape.emplace(new_shape.begin() + dim, 1);
+ const TensorShape output_shape(new_shape);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
+ if (!output->CopyFrom(ctx->input(0), output_shape)) {
+ // This should never happen, since the sizes of the input and output
+ // should always be the same (we only expand the dimension with 1).
+ ctx->SetStatus(
+ errors::Internal("Could not expand dimension with input shape ",
+ ctx->input(0).shape().DebugString(),
+ " and output shape ", output_shape.DebugString()));
+ }
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("ExpandDims").Device(DEVICE_CPU).HostMemory("dim"),
+ ExpandDimsOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("ExpandDims") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("dim"), \
+ ExpandDimsOp);
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+REGISTER_KERNEL_BUILDER(Name("ExpandDims")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .HostMemory("input")
+ .HostMemory("dim")
+ .HostMemory("output"),
+ ExpandDimsOp);
+
+class SqueezeOp : public OpKernel {
+ public:
+ explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ std::vector<int32> squeeze_dims;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
+ squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto existing_dims = ctx->input(0).shape().dim_sizes();
+ std::vector<int64> new_shape;
+
+ std::unordered_set<int32> wrapped_squeeze_dims;
+ wrapped_squeeze_dims.reserve(squeeze_dims_.size());
+ // Validate squeeze dims against the input.
+ for (int32 dim : squeeze_dims_) {
+ OP_REQUIRES(
+ ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
+ errors::InvalidArgument("Tried to squeeze dim index ", dim,
+ " for tensor with ", ctx->input(0).dims(),
+ " dimensions."));
+ // If dim is < 0, we wrap around (-1 means the last element).
+ if (dim < 0) {
+ dim = existing_dims.size() + dim;
+ }
+
+ wrapped_squeeze_dims.insert(dim);
+ }
+
+ for (size_t i = 0; i < existing_dims.size(); ++i) {
+ auto existing_dim = existing_dims[i];
+
+ // If explicit squeeze_dims were provided, only squeeze those dimensions.
+ if (!wrapped_squeeze_dims.empty()) {
+ if (wrapped_squeeze_dims.count(i) > 0) {
+ OP_REQUIRES(ctx, existing_dim == 1,
+ errors::InvalidArgument("Tried to explicitly squeeze "
+ "dimension ",
+ i, " but dimension was not 1: ",
+ existing_dim));
+ } else {
+ // This dimension is not being squeezed.
+ new_shape.push_back(existing_dim);
+ }
+ } else {
+ // Copy over all non-1-length dimensions.
+ if (existing_dim != 1) {
+ new_shape.push_back(existing_dim);
+ }
+ }
+ }
+
+ const TensorShape output_shape(new_shape);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
+ if (!output->CopyFrom(ctx->input(0), output_shape)) {
+ // This should never happen, since the sizes of the input and
+ // output should always be the same.
+ ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
+ ctx->input(0).shape().DebugString(),
+ " and output shape ",
+ output_shape.DebugString()));
+ }
+ }
+
+ private:
+ std::unordered_set<int32> squeeze_dims_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp);
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Squeeze").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SqueezeOp);
+TF_CALL_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
+} // namespace tensorflow
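The negative-dimension handling in ExpandDimsOp and SqueezeOp above is easy to get wrong, so here is a minimal standalone sketch of the same wrapping rules using plain std::vector instead of Tensor; the helper names are illustrative only and not part of the kernel.

#include <cstdint>
#include <set>
#include <vector>

// Mirrors ExpandDimsOp: a negative dim in [-rank - 1, rank] wraps into
// [0, rank], then a 1 is inserted at that position.
std::vector<int64_t> ExpandDimsShape(std::vector<int64_t> shape, int dim) {
  const int rank = static_cast<int>(shape.size());
  if (dim < 0) dim += rank + 1;  // e.g. dim == -1 appends at the end
  if (dim > rank) dim = rank;    // clamp to the end, as the kernel does
  shape.insert(shape.begin() + dim, 1);
  return shape;
}

// Mirrors SqueezeOp with explicit squeeze_dims: negative dims wrap by +rank,
// and every requested dimension is assumed to have size 1 (the kernel errors
// out otherwise). An empty set means "squeeze every size-1 dimension".
std::vector<int64_t> SqueezeShape(const std::vector<int64_t>& shape,
                                  std::set<int> squeeze_dims) {
  const int rank = static_cast<int>(shape.size());
  std::set<int> wrapped;
  for (int d : squeeze_dims) wrapped.insert(d < 0 ? d + rank : d);
  std::vector<int64_t> out;
  for (int i = 0; i < rank; ++i) {
    const bool requested = wrapped.count(i) > 0;
    if (wrapped.empty() ? shape[i] != 1 : !requested) out.push_back(shape[i]);
  }
  return out;
}

For example, ExpandDimsShape({2, 3}, -1) yields {2, 3, 1}, and SqueezeShape({1, 2, 1, 3}, {}) yields {2, 3}.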
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
new file mode 100644
index 0000000000..3477266d5d
--- /dev/null
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -0,0 +1,242 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/slice_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ LOG(FATAL) << "begin must be either int32 or int64";
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// Shared code that is not dependent on the type of T. We do this to reduce
+// code size by not duplicating all this for all T (float, double, int32, etc.)
+static void SharedValidation(OpKernelContext* context,
+ TensorShape* output_shape, bool* is_identity,
+ bool* slice_dim0,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const Tensor& input = context->input(0);
+ const Tensor& begin_tensor = context->input(1);
+ const Tensor& size_tensor = context->input(2);
+
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsLegacyVector(begin_tensor.shape()) &&
+ TensorShapeUtils::IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input.dims() &&
+ size_tensor.NumElements() == input.dims(),
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input.dims(), ", but got ", begin_tensor.NumElements(), " and ",
+ size_tensor.NumElements(), " instead."));
+
+ const int input_dims = input.dims();
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ *slice_dim0 = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input.dim_size(i), "], but got ", b));
+ OP_REQUIRES(
+ context, 0 <= s && b + s <= input.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input.dim_size(i) - b, "], but ", "got ", s));
+ }
+ output_shape->AddDim(s);
+ const bool take_all = (b == 0) && (s == input.dim_size(i));
+ (*is_identity) &= take_all;
+ (*slice_dim0) &= (i == 0) || take_all;
+ }
+}
+
+template <typename Device, typename T>
+class SliceOp : public OpKernel {
+ public:
+ explicit SliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ TensorShape output_shape;
+ bool is_identity = true;
+ bool slice_dim0 = true;
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ SharedValidation(context, &output_shape, &is_identity, &slice_dim0, &begin,
+ &size);
+ if (!context->status().ok()) return;
+ const Tensor& input = context->input(0);
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ return;
+ }
+
+ if (slice_dim0 && IsInnerDimsSizeAligned<T>(input.shape())) {
+ VLOG(1) << "Slice dim 0: " << input.shape().DebugString();
+ CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true.
+ context->set_output(0, input.Slice(begin[0], begin[0] + size[0]));
+ return;
+ }
+
+ Tensor* result = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
+ const int input_dims = input.dims();
+
+ if (output_shape.num_elements() > 0) {
+ if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
+ DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
+ auto input = context->input(0).tensor<T, 2>();
+ auto output = result->tensor<T, 2>();
+ // TODO(agarwal): Consider multi-threading this loop for cases where
+ // size[0] is very large.
+ for (int i = 0; i < size[0]; ++i) {
+ const int row = begin[0] + i;
+ if (i + 1 < size[0]) {
+ port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
+ port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
+ }
+ memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
+ }
+ return;
+ }
+#define HANDLE_DIM(NDIM) \
+ if (input_dims == NDIM) { \
+ HandleCase<NDIM>(context, begin, size, result); \
+ return; \
+ }
+
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+
+#undef HANDLE_DIM
+
+ OP_REQUIRES(context, false, errors::Unimplemented(
+ "SliceOp : Unhandled input dimensions"));
+ }
+ }
+
+ private:
+ template <int NDIM>
+ void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
+ const gtl::ArraySlice<int64>& size, Tensor* result) {
+ Eigen::DSizes<ptrdiff_t, NDIM> indices;
+ Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+ for (int i = 0; i < NDIM; ++i) {
+ indices[i] = begin[i];
+ sizes[i] = size[i];
+ }
+
+ functor::Slice<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), indices, sizes);
+ }
+};
+
+#define REGISTER_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("Slice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size"), \
+ SliceOp<CPUDevice, type>)
+
+TF_CALL_ALL_TYPES(REGISTER_SLICE);
+REGISTER_SLICE(bfloat16);
+
+#undef REGISTER_SLICE
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T, NDIM) \
+ template <> \
+ void Slice<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
+ typename TTypes<T, NDIM>::ConstTensor input, \
+ const Eigen::DSizes<ptrdiff_t, NDIM>& indices, \
+ const Eigen::DSizes<ptrdiff_t, NDIM>& sizes); \
+ extern template struct Slice<GPUDevice, T, NDIM>;
+
+#define DECLARE_FOR_N(T) \
+ DECLARE_GPU_SPEC(T, 1); \
+ DECLARE_GPU_SPEC(T, 2); \
+ DECLARE_GPU_SPEC(T, 3); \
+ DECLARE_GPU_SPEC(T, 4); \
+ DECLARE_GPU_SPEC(T, 5);
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
+DECLARE_FOR_N(int32);
+
+#undef DECLARE_FOR_N
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+#define REGISTER_GPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("Slice") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .TypeConstraint<int32>("Index"), \
+ SliceOp<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(int32);
+
+#undef REGISTER_GPU
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
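SharedValidation above turns a size of -1 into "everything from begin[i] to the end of dimension i" and detects the case where the slice covers the whole input. A minimal sketch of that normalization with plain vectors (illustrative only; it skips the bounds checks the kernel enforces):

#include <cstdint>
#include <vector>

struct SliceSpec {
  std::vector<int64_t> begin, size;
  bool is_identity;  // true when the slice covers the entire input
};

// dims: input dimension sizes; begin/size: as passed to the Slice op.
// A size of -1 is expanded to dims[i] - begin[i], matching SharedValidation.
SliceSpec NormalizeSlice(const std::vector<int64_t>& dims,
                         std::vector<int64_t> begin,
                         std::vector<int64_t> size) {
  bool identity = true;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (size[i] == -1) size[i] = dims[i] - begin[i];
    identity &= (begin[i] == 0 && size[i] == dims[i]);
    // The kernel additionally requires 0 <= begin[i] and
    // begin[i] + size[i] <= dims[i]; omitted here for brevity.
  }
  return {begin, size, identity};
}

With dims {4, 5}, begin {1, 0}, size {-1, 5}, NormalizeSlice returns begin {1, 0}, size {3, 5}, and is_identity == false.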
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
new file mode 100644
index 0000000000..1b6bd9c112
--- /dev/null
+++ b/tensorflow/core/kernels/slice_op.h
@@ -0,0 +1,25 @@
+#ifndef TENSORFLOW_KERNELS_SLICE_OP_H_
+#define TENSORFLOW_KERNELS_SLICE_OP_H_
+
+// Functor definition for SliceOp, must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, int NDIMS>
+struct Slice {
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
+ typename TTypes<T, NDIMS>::ConstTensor input,
+ const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_indices,
+ const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_sizes) {
+ output.device(d) = input.slice(slice_indices, slice_sizes);
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
diff --git a/tensorflow/core/kernels/slice_op_gpu.cu.cc b/tensorflow/core/kernels/slice_op_gpu.cu.cc
new file mode 100644
index 0000000000..6e919b244c
--- /dev/null
+++ b/tensorflow/core/kernels/slice_op_gpu.cu.cc
@@ -0,0 +1,31 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include "tensorflow/core/kernels/slice_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Slice<GPUDevice, T, 1>; \
+ template struct functor::Slice<GPUDevice, T, 2>; \
+ template struct functor::Slice<GPUDevice, T, 3>; \
+ template struct functor::Slice<GPUDevice, T, 4>; \
+ template struct functor::Slice<GPUDevice, T, 5>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+DEFINE_GPU_KERNELS(int32);
+
+#undef DEFINE_GPU_KERNELS
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/slice_op_test.cc b/tensorflow/core/kernels/slice_op_test.cc
new file mode 100644
index 0000000000..27c78c6dc0
--- /dev/null
+++ b/tensorflow/core/kernels/slice_op_test.cc
@@ -0,0 +1,73 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+// For the benchmark, we set up a 2 * kDim x kMaxSize input tensor and slice
+// a kDim x 'size' block out of it starting at (10, 10).
+template <typename T>
+static void SliceHelper(int iters, int size) {
+ testing::StopTiming();
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+ DataType dt = DataTypeToEnum<T>::v();
+ int kDim = 100;
+ int kMaxSize = 15000;
+ CHECK_LT(size, kMaxSize);
+
+ Tensor begin(DT_INT32, TensorShape({2}));
+ begin.flat<int32>()(0) = 10;
+ begin.flat<int32>()(1) = 10;
+
+ Tensor sizes(DT_INT32, TensorShape({2}));
+ sizes.flat<int32>()(0) = kDim;
+ sizes.flat<int32>()(1) = size;
+
+ Tensor input(dt, TensorShape({2 * kDim, kMaxSize}));
+ input.flat<T>().setRandom();
+
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Slice")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, begin))
+ .Input(test::graph::Constant(g, sizes))
+ .Attr("T", dt)
+ .Finalize(g, &node));
+
+ testing::BytesProcessed(static_cast<int64>(iters) * kDim * size * sizeof(T));
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+ testing::UseRealTime();
+}
+
+static void BM_SliceFloat(int iters, int dim2) {
+ SliceHelper<float>(iters, dim2);
+}
+
+BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000);
+
+static void BM_SliceBFloat16(int iters, int dim2) {
+ SliceHelper<bfloat16>(iters, dim2);
+}
+
+BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
new file mode 100644
index 0000000000..abe6331a4f
--- /dev/null
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -0,0 +1,62 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/kernels/softmax_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SoftmaxOp : public OpKernel {
+ public:
+ explicit SoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& logits_in = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+ Tensor* softmax_out = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output(0, logits_in.shape(), &softmax_out));
+ functor::SoftmaxFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ softmax_out->matrix<T>());
+ }
+};
+
+// Partial specialization for a CPUDevice, that uses the Eigen implementation
+// from SoftmaxEigenImpl.
+namespace functor {
+template <typename T>
+struct SoftmaxFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::Matrix softmax) {
+ SoftmaxEigenImpl<CPUDevice, T>::Compute(d, logits, softmax);
+ }
+};
+} // namespace functor
+
+REGISTER_KERNEL_BUILDER(Name("Softmax")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ SoftmaxOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("Softmax")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ SoftmaxOp<CPUDevice, double>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("Softmax")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ SoftmaxOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/softmax_op.h b/tensorflow/core/kernels/softmax_op.h
new file mode 100644
index 0000000000..69bd531b70
--- /dev/null
+++ b/tensorflow/core/kernels/softmax_op.h
@@ -0,0 +1,70 @@
+#ifndef TENSORFLOW_KERNELS_SOFTMAX_OP_H_
+#define TENSORFLOW_KERNELS_SOFTMAX_OP_H_
+// Functor definition for SoftmaxOp, must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by SoftmaxOp to do the computations.
+template <typename Device, typename T>
+struct SoftmaxFunctor {
+ // Computes Softmax activation.
+ //
+ // logits: dim: batch_size, num_classes.
+ // softmax: dims: batch_size, num_classes.
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::Matrix softmax);
+};
+
+// Eigen code implementing SoftmaxFunctor::operator().
+// This code works for both CPU and GPU and is used by the functor
+// specializations for both device types.
+template <typename Device, typename T>
+struct SoftmaxEigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::Matrix softmax) {
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const int batch_size = logits.dimension(kBatchDim);
+ const int num_classes = logits.dimension(kClassDim);
+
+// These arrays are used to reduce along the class dimension, and broadcast
+// the resulting value to all classes.
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::DSizes<int, 1> along_class(kClassDim);
+ Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+ Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+#else
+ Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
+ Eigen::IndexList<Eigen::type2index<1> > depth_dim;
+ Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
+ batch_by_one.set(0, batch_size);
+ Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
+ one_by_class.set(1, num_classes);
+#endif
+ // NOTE(mdevin): If you modify this implementation please run
+ // the ImageNetSoftmaxFwd benchmark in core_ops_test.cc.
+ //
+ // softmax = exp(logits - max(logits along classes));
+ softmax.device(d) = (logits -
+ logits.maximum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class)).exp();
+ // softmax = softmax / sum(softmax along classes);
+ softmax.device(d) = (softmax /
+ softmax.sum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SOFTMAX_OP_H_
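SoftmaxEigenImpl above evaluates the max-shifted softmax in two broadcasted Eigen expressions. The equivalent scalar computation, written out only to make the two passes explicit (a hypothetical helper, not part of the op):

#include <algorithm>
#include <cmath>
#include <vector>

// logits: batch_size x num_classes, row-major. Returns the softmax of each row.
std::vector<float> SoftmaxRows(const std::vector<float>& logits,
                               int batch_size, int num_classes) {
  std::vector<float> out(logits.size());
  for (int b = 0; b < batch_size; ++b) {
    const float* row = &logits[b * num_classes];
    float* dst = &out[b * num_classes];
    // Pass 1: subtract the row max before exponentiating, for numerical
    // stability, then accumulate the normalizer.
    const float row_max = *std::max_element(row, row + num_classes);
    float sum = 0.f;
    for (int c = 0; c < num_classes; ++c) {
      dst[c] = std::exp(row[c] - row_max);
      sum += dst[c];
    }
    // Pass 2: normalize so each row sums to 1.
    for (int c = 0; c < num_classes; ++c) dst[c] /= sum;
  }
  return out;
}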
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
new file mode 100644
index 0000000000..d5aaf9c364
--- /dev/null
+++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -0,0 +1,31 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/softmax_op.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization for a GPUDevice, that uses the Eigen implementation
+// from SoftmaxEigenImpl.
+namespace functor {
+template <typename T>
+struct SoftmaxFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::Matrix softmax) {
+ SoftmaxEigenImpl<GPUDevice, T>::Compute(d, logits, softmax);
+ }
+};
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::SoftmaxFunctor<GPUDevice, float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc
new file mode 100644
index 0000000000..b5fb57d3c5
--- /dev/null
+++ b/tensorflow/core/kernels/softplus_op.cc
@@ -0,0 +1,97 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/softplus_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SoftplusOp : public UnaryElementWiseOp<T, SoftplusOp<Device, T>> {
+ public:
+ using UnaryElementWiseOp<T, SoftplusOp<Device, T>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::Softplus<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+template <typename Device, typename T>
+class SoftplusGradOp
+ : public BinaryElementWiseOp<T, SoftplusGradOp<Device, T>> {
+ public:
+ using BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>::BinaryElementWiseOp;
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): inputs that were passed to SoftplusOp()
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OP_REQUIRES(context, a.IsSameSize(g),
+ errors::InvalidArgument("g and a must be the same size"));
+ functor::SoftplusGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+ output->flat<T>());
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Softplus").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SoftplusOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SoftplusGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SoftplusGradOp<CPUDevice, type>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void Softplus<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct Softplus<GPUDevice, T>; \
+ \
+ template <> \
+ void SoftplusGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct SoftplusGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Softplus").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SoftplusOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SoftplusGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SoftplusGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/softplus_op.h b/tensorflow/core/kernels/softplus_op.h
new file mode 100644
index 0000000000..3545a78246
--- /dev/null
+++ b/tensorflow/core/kernels/softplus_op.h
@@ -0,0 +1,46 @@
+#ifndef TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
+#define TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
+// Functor definition for SoftplusOp and SoftplusGradOp, must be compilable by
+// nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by SoftplusOp to do the computations.
+template <typename Device, typename T>
+struct Softplus {
+ // Computes Softplus activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor activations) {
+ activations.device(d) =
+ (features > features.constant(30.f))
+ .select(features, (features.exp() + features.constant(1.0f)).log());
+ }
+};
+
+// Functor used by SoftplusGradOp to do the computations.
+template <typename Device, typename T>
+struct SoftplusGrad {
+ // Computes SoftplusGrad backprops.
+ //
+ // gradients: gradients backpropagated to the Softplus op.
+ // features: inputs that were passed to the Softplus op.
+ // backprops: gradients to backpropagate to the Softplus inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ gradients / ((-features).exp() + features.constant(1.0f));
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
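The Softplus and SoftplusGrad functors are purely elementwise, so the same math can be shown as scalar functions. A sketch that keeps the functor's 30.0f overflow cutoff (std::log1p is used here for convenience; the functor itself writes (exp(x) + 1).log()):

#include <cmath>

// softplus(x) = log(1 + exp(x)); for large x this is ~x and exp(x) would
// overflow, so values above the cutoff are returned directly.
inline float SoftplusScalar(float x) {
  return x > 30.f ? x : std::log1p(std::exp(x));
}

// d/dx softplus(x) = sigmoid(x), so the backprop is g * sigmoid(x),
// written as g / (1 + exp(-x)) exactly like SoftplusGrad.
inline float SoftplusGradScalar(float g, float x) {
  return g / (1.f + std::exp(-x));
}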
diff --git a/tensorflow/core/kernels/softplus_op_gpu.cu.cc b/tensorflow/core/kernels/softplus_op_gpu.cu.cc
new file mode 100644
index 0000000000..7a974321a7
--- /dev/null
+++ b/tensorflow/core/kernels/softplus_op_gpu.cu.cc
@@ -0,0 +1,25 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include "tensorflow/core/kernels/softplus_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Definition of the GPU implementations declared in softplus_op.cc.
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Softplus<GPUDevice, T>; \
+ template struct functor::SoftplusGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/sparse_concat_op.cc b/tensorflow/core/kernels/sparse_concat_op.cc
new file mode 100644
index 0000000000..72c267a47d
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_concat_op.cc
@@ -0,0 +1,139 @@
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SparseConcatOp : public OpKernel {
+ public:
+ explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ OpInputList inds;
+ OP_REQUIRES_OK(context, context->input_list("indices", &inds));
+ const int N = inds.size();
+ for (int i = 0; i < N; i++) {
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
+ errors::InvalidArgument(
+ "Input indices should be a matrix but received shape ",
+ inds[i].shape().DebugString(), " at position ", i));
+ }
+
+ OpInputList vals;
+ OP_REQUIRES_OK(context, context->input_list("values", &vals));
+ OP_REQUIRES(context, vals.size() == N,
+ errors::InvalidArgument("Expected ", N, " input values, got ",
+ vals.size()));
+ for (int i = 0; i < N; i++) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
+ errors::InvalidArgument(
+ "Input values should be a vector but received shape ",
+ vals[i].shape().DebugString(), " at position ", i));
+ }
+
+ OpInputList shapes;
+ OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
+ OP_REQUIRES(context, shapes.size() == N,
+ errors::InvalidArgument("Expected ", N, " input shapes, got ",
+ shapes.size()));
+ for (int i = 0; i < N; i++) {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
+ errors::InvalidArgument(
+ "Input shapes should be a vector but received shape ",
+ shapes[i].shape().DebugString(), " at position ", i));
+ }
+
+ const TensorShape input_shape(shapes[0].vec<int64>());
+ OP_REQUIRES(
+ context, concat_dim_ >= 0 && concat_dim_ < input_shape.dims(),
+ errors::InvalidArgument("Concat dimension must be between 0 and rank (",
+ input_shape.dims(), "), got ", concat_dim_));
+ for (int i = 1; i < N; ++i) {
+ const TensorShape current_shape(shapes[i].vec<int64>());
+ OP_REQUIRES(context, current_shape.dims() == input_shape.dims(),
+ errors::InvalidArgument(
+ "Ranks of all input tensors must match: expected ",
+ input_shape.dims(), " but got ", current_shape.dims(),
+ " at position ", i));
+ for (int j = 0; j < input_shape.dims(); ++j) {
+ if (j != concat_dim_) {
+ OP_REQUIRES(
+ context, input_shape.dim_size(j) == current_shape.dim_size(j),
+ errors::InvalidArgument(
+ "Input shapes must match: expected ", input_shape.dim_size(j),
+ " for dimension ", j, " but got ", current_shape.dim_size(j),
+ " at position ", i));
+ }
+ }
+ }
+
+ // The input and output sparse tensors are assumed to be ordered along
+ // increasing dimension number. But in order for concat to work properly,
+ // order[0] must be concat_dim. So we will reorder the inputs to the
+ // concat ordering, concatenate, then reorder back to the standard order.
+ // We make a deep copy of the input tensors to ensure that the in-place
+ // reorder doesn't create race conditions for other ops that may be
+ // concurrently reading the indices and values tensors.
+
+ gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
+ std::iota(std_order.begin(), std_order.end(), 0);
+
+ std::vector<int64> concat_order;
+ concat_order.reserve(input_shape.dims());
+ concat_order.push_back(concat_dim_);
+ for (int j = 0; j < input_shape.dims(); ++j) {
+ if (j != concat_dim_) {
+ concat_order.push_back(j);
+ }
+ }
+
+ std::vector<sparse::SparseTensor> sp_inputs;
+ for (int i = 0; i < N; ++i) {
+ const TensorShape current_shape(shapes[i].vec<int64>());
+ sp_inputs.emplace_back(tensor::DeepCopy(inds[i]),
+ tensor::DeepCopy(vals[i]), current_shape,
+ std_order);
+ sp_inputs[i].Reorder<T>(concat_order);
+ }
+
+ sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
+ concat.Reorder<T>(std_order);
+
+ context->set_output(0, concat.indices());
+ context->set_output(1, concat.values());
+
+ Tensor* output_shape_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 2, TensorShape({concat.shape().dims()}),
+ &output_shape_out));
+ auto output_shape = output_shape_out->vec<int64>();
+ for (int j = 0; j < concat.shape().dims(); ++j) {
+ output_shape(j) = concat.shape().dim_size(j);
+ }
+ }
+
+ private:
+ int concat_dim_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseConcatOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+} // namespace tensorflow
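The comment in Compute above explains why the inputs are reordered so that concat_dim leads the dimension order. At its core, concatenating sparse tensors only offsets the concat-dim coordinate of every nonzero in the later inputs; a rough sketch with plain vectors (it leaves out the final re-sort back into standard order that the kernel performs):

#include <cstdint>
#include <utility>
#include <vector>

struct SparseNd {
  std::vector<std::vector<int64_t>> indices;  // one coordinate vector per value
  std::vector<float> values;
  std::vector<int64_t> shape;
};

// Concatenates two sparse tensors along concat_dim by shifting the second
// input's coordinates by the first input's extent in that dimension.
SparseNd ConcatSparse(const SparseNd& a, const SparseNd& b, int concat_dim) {
  SparseNd out = a;
  const int64_t offset = a.shape[concat_dim];
  for (size_t i = 0; i < b.indices.size(); ++i) {
    std::vector<int64_t> ix = b.indices[i];
    ix[concat_dim] += offset;
    out.indices.push_back(std::move(ix));
    out.values.push_back(b.values[i]);
  }
  out.shape[concat_dim] += b.shape[concat_dim];
  return out;
}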
diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc
new file mode 100644
index 0000000000..919e129ff8
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_matmul_op.cc
@@ -0,0 +1,192 @@
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+void PrefetchBlockNTA(const T& tensor, int si, int ei, int sj, int ej) {
+ for (int i = si; i < ei; ++i) {
+ for (int j = sj; j < ej; j = j + 16) {
+ port::prefetch<port::PREFETCH_HINT_NTA>(&tensor(i, j));
+ }
+ }
+}
+
+template <typename T>
+void PrefetchBlockT1(const T& tensor, int si, int ei, int sj, int ej) {
+ for (int i = si; i < ei; ++i) {
+ for (int j = sj; j < ej; j = j + 16) {
+ port::prefetch<port::PREFETCH_HINT_T1>(&tensor(i, j));
+ }
+ }
+}
+
+struct Block {
+ Block(int sm, int em, int sk, int ek, int sn, int en)
+ : startm(sm), endm(em), startk(sk), endk(ek), startn(sn), endn(en) {}
+
+ int startm;
+ int endm;
+ int startk;
+ int endk;
+ int startn;
+ int endn;
+};
+
+bool NextBlock(const int Bm, const int Bk, const int Bn, const int m_start,
+ const int m, const int k, const int n, const Block& b,
+ Block* next) {
+ *next = b;
+ if (b.endk < k) {
+ next->startk = b.endk;
+ next->endk = std::min(b.endk + Bk, k);
+ } else {
+ next->startk = 0;
+ next->endk = std::min(Bk, k);
+ if (b.endm < m) {
+ next->startm = b.endm;
+ next->endm = std::min(b.endm + Bm, m);
+ } else {
+ next->startm = m_start;
+ next->endm = std::min(m_start + Bm, m);
+ next->startn = b.endn;
+ next->endn = std::min(b.endn + Bn, n);
+ }
+ }
+ return next->startn == next->endn;
+}
+
+class SparseMatMulOp : public OpKernel {
+ public:
+ explicit SparseMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &a_is_sparse_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &b_is_sparse_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& a = ctx->input(0);
+ const Tensor& b = ctx->input(1);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
+ errors::InvalidArgument("a is not a matrix"));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
+ errors::InvalidArgument("b is not a matrix"));
+
+ auto left = a.matrix<float>();
+ auto right_mat = b.matrix<float>();
+ const int m = transpose_a_ ? left.dimension(1) : left.dimension(0);
+ const int k = transpose_a_ ? left.dimension(0) : left.dimension(1);
+ const int n =
+ transpose_b_ ? right_mat.dimension(0) : right_mat.dimension(1);
+ const int k2 =
+ transpose_b_ ? right_mat.dimension(1) : right_mat.dimension(0);
+
+ OP_REQUIRES(ctx, k == k2,
+ errors::InvalidArgument("Matrix size incompatible: a: ",
+ a.shape().DebugString(), ", b: ",
+ b.shape().DebugString()));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
+ auto out = output->matrix<float>();
+
+ if (!a_is_sparse_) {
+ // Fallback to Eigen contract.
+ // Note that we currently don't optimize the case where only right is
+ // sparse. That can generally be handled by transposing the order of the
+ // matmul.
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
+ dim_pair[0].first = transpose_a_ ? 0 : 1;
+ dim_pair[0].second = transpose_b_ ? 1 : 0;
+ out.device(ctx->template eigen_device<CPUDevice>()) =
+ left.contract(right_mat, dim_pair);
+ return;
+ }
+ typedef Eigen::Tensor<float, 2, Eigen::RowMajor> Matrix;
+ std::unique_ptr<Matrix> right_tr_mat;
+ std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map;
+ if (transpose_b_) {
+ right_tr_mat.reset(new Matrix(k, n));
+ Eigen::array<int, 2> perm({1, 0});
+ right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) =
+ right_mat.shuffle(perm);
+ right_tr_map.reset(new TTypes<float>::ConstMatrix(
+ right_tr_mat->data(), right_tr_mat->dimensions()));
+ }
+ TTypes<float>::ConstMatrix& right =
+ transpose_b_ ? *right_tr_map : right_mat;
+
+ const bool transpose_a = transpose_a_;
+
+ typedef Eigen::TensorMap<Eigen::Tensor<float, 1, Eigen::RowMajor>,
+ Eigen::Unaligned> TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>,
+ Eigen::Unaligned> ConstTensorMap;
+ typedef Eigen::DSizes<Eigen::DenseIndex, 1> DSizes;
+ const int Bm = 16;
+ const int Bk = 16;
+ const int Bn = 1024;
+
+ auto work_shard = [m, n, k, transpose_a, Bm, Bk, Bn, &left, &right, &out](
+ int64 start64, int64 end64) {
+ const int start = static_cast<int>(start64);
+ const int end = static_cast<int>(end64);
+ Block curr(start, std::min(start + Bm, end), 0, std::min(Bk, k), 0,
+ std::min(Bn, n));
+ Block next(curr);
+ bool done = false;
+ for (int i = start; i < end; ++i) {
+ out.chip<0>(i).setZero();
+ }
+ while (true) {
+ done = NextBlock(Bm, Bk, Bn, start, end, k, n, curr, &next);
+
+ PrefetchBlockT1(right, curr.startk, curr.endk, curr.startn, curr.endn);
+
+ // Process current block
+ for (int i = curr.startm; i < curr.endm; ++i) {
+ PrefetchBlockNTA(left, i, i + 1, curr.startk, curr.endk);
+ PrefetchBlockNTA(out, i, i + 1, curr.startn, curr.endn);
+ DSizes out_slice_shape(curr.endn - curr.startn);
+ TensorMap out_i(&out(i, curr.startn), out_slice_shape);
+ for (int j = curr.startk; j < curr.endk; ++j) {
+ const float l = transpose_a ? left(j, i) : left(i, j);
+ if (l == 0) continue;
+ ConstTensorMap right_j(&right(j, curr.startn), out_slice_shape);
+ out_i += right_j * l;
+ }
+ }
+ if (done) break;
+ curr = next;
+ }
+ };
+ auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, m, 2 * k * n,
+ work_shard);
+ }
+
+ private:
+ bool transpose_a_;
+ bool transpose_b_;
+ bool a_is_sparse_;
+ bool b_is_sparse_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SparseMatMulOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("SparseMatMul").Device(DEVICE_CPU),
+ SparseMatMulOp);
+
+} // end namespace tensorflow
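Stripped of the cache blocking, prefetching, and sharding, the sparse path of SparseMatMulOp is an accumulation that skips zero entries of the left matrix. A dense-storage sketch of that inner loop, assuming *out has already been zeroed the way the kernel zeroes each output row:

#include <vector>

// out (m x n) += left (m x k) * right (k x n), all row-major dense buffers.
// Zeros in 'left' are skipped, which is where the sparsity win comes from;
// the kernel does the same work per (Bm x Bk x Bn) cache block.
void SparseLeftMatMul(const std::vector<float>& left,
                      const std::vector<float>& right,
                      std::vector<float>* out, int m, int k, int n) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < k; ++j) {
      const float l = left[i * k + j];
      if (l == 0.f) continue;  // skip the whole row of work for a zero entry
      const float* right_j = &right[j * n];
      float* out_i = &(*out)[i * n];
      for (int c = 0; c < n; ++c) out_i[c] += l * right_j[c];
    }
  }
}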
diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc
new file mode 100644
index 0000000000..883d0d1224
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc
@@ -0,0 +1,139 @@
+#include "tensorflow/core/framework/types.pb.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+random::PhiloxRandom philox(1, 1);
+random::SimplePhilox rnd(&philox);
+
+void Sparsify(Tensor* t, float sparsity) {
+ const int64 N = t->NumElements();
+ CHECK_LE(sparsity, 1);
+ if (sparsity <= 0) return;
+ auto flat = t->flat<float>();
+ static const uint32 K = 10000;
+ for (int64 i = 0; i < N; ++i) {
+ if (rnd.Uniform(K) < sparsity * K) {
+ flat(i) = 0;
+ }
+ }
+}
+
+Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a,
+ bool transpose_b, bool a_sparse, bool b_sparse) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "SparseMatMul")
+ .Input(in0)
+ .Input(in1)
+ .Attr("transpose_a", transpose_a)
+ .Attr("transpose_b", transpose_b)
+ .Attr("a_is_sparse", a_sparse)
+ .Attr("b_is_sparse", b_sparse)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity,
+ bool transpose_a, bool transpose_b,
+ bool a_sparse, bool b_sparse) {
+ a_sparse = a_sparse && (sparsity > 0);
+ b_sparse = b_sparse && (sparsity > 0);
+
+ auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d});
+ Tensor left(DataTypeToEnum<float>::value, left_shape);
+ left.flat<float>().setRandom();
+ if (a_sparse) {
+ Sparsify(&left, sparsity);
+ }
+
+ auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n});
+ Tensor right(DataTypeToEnum<float>::value, right_shape);
+ right.flat<float>().setRandom();
+ if (b_sparse) {
+ Sparsify(&right, sparsity);
+ }
+
+ SparseMatMulNode(g, test::graph::Constant(g, left),
+ test::graph::Constant(g, right), transpose_a, transpose_b,
+ a_sparse, b_sparse);
+ return g;
+}
+
+static Graph* SparseMatMul(int m, int n, int d, float sparsity,
+ bool transpose_a, bool transpose_b) {
+ Graph* g = new Graph(OpRegistry::Global());
+ return SparseMatMulHelper(g, m, n, d, sparsity, transpose_a, transpose_b,
+ true, false);
+}
+
+static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_a,
+ float sparsity_b) {
+ Graph* g = new Graph(OpRegistry::Global());
+ if (sparsity_a == 0 && sparsity_b > 0) {
+ SparseMatMulHelper(g, m, n, d, sparsity_a, false, false, false, false);
+ SparseMatMulHelper(g, n, d, m, sparsity_b, true, true, true, false);
+ SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false);
+ } else {
+ SparseMatMulHelper(g, m, n, d, sparsity_a, false, true, true, false);
+ SparseMatMulHelper(g, d, n, m, sparsity_a, true, false, true, true);
+ SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false);
+ }
+ return g;
+}
+
+#define BM_SPARSE(M, K, N, S) \
+ static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
+ std::string label = strings::Printf("%d_%d_%d_%0.2f", M, K, N, S / 100.0); \
+ testing::SetLabel(label); \
+ test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, false, false)) \
+ .Run(iters); \
+ } \
+ BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S);
+
+BM_SPARSE(2048, 2048, 2048, 0);
+BM_SPARSE(2048, 2048, 2048, 1);
+BM_SPARSE(2048, 2048, 2048, 85);
+
+BM_SPARSE(1024, 1024, 1024, 0);
+BM_SPARSE(1024, 1024, 1024, 1);
+BM_SPARSE(1024, 1024, 1024, 85);
+
+BM_SPARSE(256, 256, 256, 1);
+BM_SPARSE(512, 512, 512, 1);
+
+#define BM_SPARSE_MULTI(M, K, N, S1, S2) \
+ static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2 * 3); \
+ std::string label = strings::Printf("%d_%d_%d_%0.2f_%0.2f", M, K, N, \
+ S1 / 100.0, S2 / 100.0); \
+ testing::SetLabel(label); \
+ test::Benchmark("cpu", MultiSparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0)) \
+ .Run(iters); \
+ } \
+ BENCHMARK(BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2);
+
+BM_SPARSE_MULTI(512, 2140, 4096, 0, 82);
+BM_SPARSE_MULTI(512, 4096, 2048, 83, 83);
+
+#define BM_SPARSE_TR(M, K, N, S, TA, TB) \
+ static void BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
+ std::string label = \
+ strings::Printf("%d_%d_%d_%d_%d_%0.2f", M, K, N, TA, TB, S / 100.0); \
+ testing::SetLabel(label); \
+ test::Benchmark("cpu", SparseMatMul(M, N, K, S / 100.0, TA, TB)) \
+ .Run(iters); \
+ } \
+ BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB);
+
+BM_SPARSE_TR(2048, 2048, 2048, 1, true, false);
+BM_SPARSE_TR(2048, 2048, 2048, 1, false, true);
+BM_SPARSE_TR(2048, 2048, 2048, 1, true, true);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_reorder_op.cc b/tensorflow/core/kernels/sparse_reorder_op.cc
new file mode 100644
index 0000000000..fd6824a4e2
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_reorder_op.cc
@@ -0,0 +1,71 @@
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SparseReorderOp : public OpKernel {
+ public:
+ explicit SparseReorderOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_ind = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_ind.shape()),
+ errors::InvalidArgument(
+ "Input indices should be a matrix but received shape",
+ input_ind.shape().DebugString()));
+
+ const Tensor& input_val = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_val.shape()),
+ errors::InvalidArgument(
+ "Input values should be a vector but received shape",
+ input_val.shape().DebugString()));
+
+ const Tensor& input_shape_in = context->input(2);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
+ errors::InvalidArgument(
+ "Input shape should be a vector but received shape",
+ input_shape_in.shape().DebugString()));
+
+ const TensorShape input_shape(input_shape_in.vec<int64>());
+
+ gtl::InlinedVector<int64, 8> std_order(input_shape.dims());
+ std::iota(std_order.begin(), std_order.end(), 0);
+
+ // Check if the sparse tensor is already ordered correctly
+ sparse::SparseTensor input_sp(input_ind, input_val, input_shape, std_order);
+
+ if (input_sp.IndicesValid()) {
+ context->set_output(0, input_sp.indices());
+ context->set_output(1, input_sp.values());
+ } else {
+ // Deep-copy the input Tensors, then reorder in-place
+ sparse::SparseTensor reordered_sp(tensor::DeepCopy(input_ind),
+ tensor::DeepCopy(input_val),
+ input_shape);
+ reordered_sp.Reorder<T>(std_order);
+ context->set_output(0, reordered_sp.indices());
+ context->set_output(1, reordered_sp.values());
+ }
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseReorder").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseReorderOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+} // namespace tensorflow
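Reordering a sparse tensor into standard (row-major) order amounts to a lexicographic sort of its coordinate rows with the values carried along. A minimal sketch with plain vectors, not the in-place SparseTensor::Reorder the kernel calls:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// indices: one coordinate vector per nonzero; values: parallel array.
// Sorts both into row-major (lexicographic) order.
void ReorderSparse(std::vector<std::vector<int64_t>>* indices,
                   std::vector<float>* values) {
  std::vector<size_t> perm(indices->size());
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(), [&](size_t a, size_t b) {
    return (*indices)[a] < (*indices)[b];  // vectors compare lexicographically
  });
  std::vector<std::vector<int64_t>> new_ix(indices->size());
  std::vector<float> new_vals(values->size());
  for (size_t i = 0; i < perm.size(); ++i) {
    new_ix[i] = std::move((*indices)[perm[i]]);
    new_vals[i] = (*values)[perm[i]];
  }
  *indices = std::move(new_ix);
  *values = std::move(new_vals);
}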
diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc
new file mode 100644
index 0000000000..47e91c134d
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_to_dense_op.cc
@@ -0,0 +1,129 @@
+// See core/ops/sparse_ops.cc for documentation.
+//
+// NOTE: the operations in this file are only suitable for execution
+// on CPUs.
+
+#define EIGEN_USE_THREADS
+
+#include <string>
+#include <sstream>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+// Operator to convert sparse representations to dense.
+template <typename T, typename Index>
+class SparseToDense : public OpKernel {
+ public:
+ explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ // sparse_indices
+ const Tensor& indices = c->input(0);
+ OP_REQUIRES(c, indices.dims() <= 2,
+ errors::InvalidArgument(
+ "sparse_indices should be a scalar, vector, or matrix, "
+ "got shape ",
+ indices.shape().ShortDebugString()));
+ const int64 num_elems = indices.dims() > 0 ? indices.dim_size(0) : 1;
+ const int64 num_dims = indices.dims() > 1 ? indices.dim_size(1) : 1;
+
+ // output_shape
+ const Tensor& output_shape = c->input(1);
+ OP_REQUIRES(
+ c, TensorShapeUtils::IsLegacyVector(output_shape.shape()),
+ errors::InvalidArgument("output_shape should be a vector, got shape ",
+ output_shape.shape().ShortDebugString()));
+ OP_REQUIRES(c, output_shape.NumElements() == num_dims,
+ errors::InvalidArgument(
+ "output_shape has incorrect number of elements: ",
+ output_shape.NumElements(), " should be: ", num_dims));
+
+ // sparse_values
+ const Tensor& sparse_values = c->input(2);
+ const int64 num_values = sparse_values.NumElements();
+ OP_REQUIRES(
+ c, sparse_values.dims() == 0 ||
+ (sparse_values.dims() == 1 && num_values == num_elems),
+ errors::InvalidArgument("sparse_values has incorrect shape ",
+ sparse_values.shape().ShortDebugString(),
+ ", should be [] or [", num_elems, "]"));
+
+ // default_value
+ const Tensor& default_value = c->input(3);
+ OP_REQUIRES(c, TensorShapeUtils::IsScalar(default_value.shape()),
+ errors::InvalidArgument("default_value should be a scalar."));
+
+ auto output_shape_vec = output_shape.flat<Index>();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShapeUtils::MakeShape(
+ output_shape_vec.data(),
+ output_shape_vec.size()),
+ &output));
+
+ TensorShape ix_shape({num_elems, num_dims});
+ Tensor indices_shaped(DT_INT64, ix_shape);
+ if (indices.dtype() == DT_INT64) {
+ CHECK(indices_shaped.CopyFrom(indices, ix_shape));
+ } else {
+ indices_shaped.matrix<int64>() =
+ indices.shaped<Index, 2>(ix_shape.dim_sizes()).template cast<int64>();
+ }
+
+ // If we received a scalar, we'll need to create a new
+ // tensor with copies of the values as a vec.
+ // TODO(ebrevdo): find a way to avoid this temp allocation.
+ Tensor sparse_values_b;
+
+ if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
+ OP_REQUIRES_OK(
+ c, c->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({num_elems}), &sparse_values_b));
+ sparse_values_b.vec<T>().setConstant(sparse_values.scalar<T>()());
+ } else {
+ sparse_values_b = sparse_values;
+ }
+
+ gtl::InlinedVector<int64, 8> order(output->shape().dims());
+ std::iota(order.begin(), order.end(), 0); // Assume order is correct
+ sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(),
+ order);
+
+ output->flat<T>().setConstant(default_value.scalar<T>()());
+ OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
+ errors::InvalidArgument(
+ "Indices are not valid (out of bounds). Shape: ",
+ output->shape().DebugString()));
+ }
+};
+
+#define REGISTER_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("SparseToDense") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ SparseToDense<type, index_type>);
+
+#define REGISTER_KERNELS_ALL(type) \
+ REGISTER_KERNELS(type, int32); \
+ REGISTER_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL);
+REGISTER_KERNELS_ALL(bool);
+REGISTER_KERNELS_ALL(string);
+
+#undef REGISTER_KERNELS_ALL
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
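SparseToDense is a fill-then-scatter: the output is set to default_value everywhere, then each sparse index overwrites a single element, with a scalar sparse_values broadcast to all indices. A 1-D sketch without the validation the kernel performs (the helper name is hypothetical):

#include <cstdint>
#include <vector>

// indices: positions to set; values: either one value per index or a single
// value broadcast to all of them, matching the op's sparse_values contract.
std::vector<float> SparseToDense1D(const std::vector<int64_t>& indices,
                                   const std::vector<float>& values,
                                   int64_t output_size, float default_value) {
  std::vector<float> dense(output_size, default_value);
  for (size_t i = 0; i < indices.size(); ++i) {
    const float v = values.size() == 1 ? values[0] : values[i];
    dense[indices[i]] = v;  // the kernel rejects out-of-bounds indices
  }
  return dense;
}

For indices {1, 3, 4}, a single value 2, output size 5, and default -2 this reproduces the expected output {-2, 2, -2, 2, 2} from the OneD_OneValue test in the accompanying test file.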
diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc
new file mode 100644
index 0000000000..e9800ccd68
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc
@@ -0,0 +1,283 @@
+#include <functional>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace {
+
+class SparseToDenseTest : public OpsTestBase {
+ protected:
+ void SetUp() override { RequireDefaultOps(); }
+
+ void MakeOp(int dim, DataType index_type, DataType value_type) {
+ ASSERT_OK(NodeDefBuilder("sparsetodense", "SparseToDense")
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(index_type))
+ .Input(FakeInput(value_type))
+ .Input(FakeInput(value_type))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SparseToDenseTest, OneD_OneValue) {
+ MakeOp(1, DT_INT32, DT_FLOAT);
+
+ // sparse_indices
+ AddInputFromArray<int32>(TensorShape({3}), {1, 3, 4});
+ // output_shape
+ AddInputFromArray<int32>(TensorShape({1}), {5});
+ // sparse_values
+ AddInputFromArray<float>(TensorShape({}), {2});
+ // default_value
+ AddInputFromArray<float>(TensorShape({}), {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, {5});
+ test::FillValues<float>(&expected, {-2, 2, -2, 2, 2});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(SparseToDenseTest, OneD_OneValue_int64_double) {
+ MakeOp(1, DT_INT64, DT_DOUBLE);
+
+ // sparse_indices
+ AddInputFromArray<int64>(TensorShape({3}), {1, 3, 4});
+ // output_shape
+ AddInputFromArray<int64>(TensorShape({1}), {5});
+ // sparse_values
+ AddInputFromArray<double>(TensorShape({}), {2});
+ // default_value
+ AddInputFromArray<double>(TensorShape({}), {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_DOUBLE, {5});
+ test::FillValues<double>(&expected, {-2, 2, -2, 2, 2});
+ test::ExpectTensorEqual<double>(expected, *GetOutput(0));
+}
+
+TEST_F(SparseToDenseTest, OneD_MultValues) {
+ MakeOp(1, DT_INT32, DT_FLOAT);
+
+ // sparse_indices
+ AddInputFromArray<int32>({3}, {1, 3, 4});
+ // output_shape
+ AddInputFromArray<int32>({1}, {5});
+ // sparse_values
+ AddInputFromArray<float>({3}, {3, 4, 5});
+ // default_value
+ AddInputFromArray<float>({}, {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, {5});
+ test::FillValues<float>(&expected, {-2, 3, -2, 4, 5});
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(SparseToDenseTest, TwoD_OneValue) {
+ MakeOp(2, DT_INT32, DT_FLOAT);
+
+ // sparse_indices
+ AddInputFromArray<int32>(TensorShape({3, 2}), {0, 1, 0, 2, 2, 3});
+ // output_shape
+ AddInputFromArray<int32>(TensorShape({2}), {3, 4});
+ // sparse_values
+ AddInputFromArray<float>(TensorShape({}), {2});
+ // default_value
+ AddInputFromArray<float>(TensorShape({}), {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, {3, 4});
+ expected.flat<float>().setConstant(-2);
+ expected.tensor<float, 2>()(0, 1) = 2;
+ expected.tensor<float, 2>()(0, 2) = 2;
+ expected.tensor<float, 2>()(2, 3) = 2;
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(SparseToDenseTest, TwoD_MultValues) {
+ MakeOp(2, DT_INT32, DT_FLOAT);
+
+ // sparse_indices
+ AddInputFromArray<int32>(TensorShape({3, 2}), {0, 1, 0, 2, 2, 3});
+ // output_shape
+ AddInputFromArray<int32>(TensorShape({2}), {3, 4});
+ // sparse_values
+ AddInputFromArray<float>(TensorShape({3}), {3, 4, 5});
+ // default_value
+ AddInputFromArray<float>(TensorShape({}), {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, {3, 4});
+ expected.flat<float>().setConstant(-2);
+ expected.tensor<float, 2>()(0, 1) = 3;
+ expected.tensor<float, 2>()(0, 2) = 4;
+ expected.tensor<float, 2>()(2, 3) = 5;
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(SparseToDenseTest, ThreeD_OneValue) {
+ MakeOp(3, DT_INT32, DT_FLOAT);
+
+ // sparse_indices
+ AddInputFromArray<int32>(TensorShape({3, 3}), {0, 1, 1, 0, 2, 0, 2, 3, 1});
+ // output_shape
+ AddInputFromArray<int32>(TensorShape({3}), {3, 4, 2});
+ // sparse_values
+ AddInputFromArray<float>(TensorShape({}), {2});
+ // default_value
+ AddInputFromArray<float>(TensorShape({}), {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, {3, 4, 2});
+ expected.flat<float>().setConstant(-2);
+ expected.tensor<float, 3>()(0, 1, 1) = 2;
+ expected.tensor<float, 3>()(0, 2, 0) = 2;
+ expected.tensor<float, 3>()(2, 3, 1) = 2;
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+TEST_F(SparseToDenseTest, ThreeD_MultValues) {
+ MakeOp(3, DT_INT32, DT_FLOAT);
+
+ // sparse_indices
+ AddInputFromArray<int32>(TensorShape({3, 3}), {0, 1, 1, 0, 2, 0, 2, 3, 1});
+ // output_shape
+ AddInputFromArray<int32>(TensorShape({3}), {3, 4, 2});
+ // sparse_values
+ AddInputFromArray<float>(TensorShape({3}), {3, 4, 5});
+ // default_value
+ AddInputFromArray<float>(TensorShape({}), {-2});
+
+ ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_FLOAT, {3, 4, 2});
+ expected.flat<float>().setConstant(-2);
+ expected.tensor<float, 3>()(0, 1, 1) = 3;
+ expected.tensor<float, 3>()(0, 2, 0) = 4;
+ expected.tensor<float, 3>()(2, 3, 1) = 5;
+ test::ExpectTensorEqual<float>(expected, *GetOutput(0));
+}
+
+} // namespace
+
+static int BM_Arg(int ndim, int n) { return (ndim * 1000000) + n; }
+static int NDIM_from_arg(int bm_arg) { return bm_arg / 1000000; }
+static int N_from_arg(int bm_arg) { return bm_arg % 1000000; }
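+// For example, BM_Arg(3, 1000) packs to 3001000, which NDIM_from_arg and
+// N_from_arg decode back to NDIM = 3 and N = 1000.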
+
+static void BM_SparseToDense(int iters, const int bm_arg) {
+ const int NDIM = NDIM_from_arg(bm_arg);
+ const int N = N_from_arg(bm_arg);
+ // TODO(zhifengc): Switch to use kernel_benchmark_testlib.h
+ tensorflow::testing::StopTiming();
+
+ const int IndexDim = (NDIM == 1) ? 0 : 1;
+
+ std::unique_ptr<Device> device(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+
+  // Build the op inputs: an output shape of size N along IndexDim (and 3
+  // along every other dim), N sparse indices running along IndexDim, N
+  // values, and a scalar default value.
+ Tensor output_shape(DT_INT32, TensorShape({NDIM}));
+ Tensor sparse_indices(DT_INT32, TensorShape({N, NDIM}));
+ Tensor sparse_values(DT_FLOAT, TensorShape({N}));
+ Tensor default_value(DT_FLOAT, TensorShape({}));
+ auto output_shape_t = output_shape.vec<int32>();
+ for (int d = 0; d < NDIM; ++d) {
+ output_shape_t(d) = (d == IndexDim) ? N : 3;
+ }
+
+ auto sparse_indices_t = sparse_indices.matrix<int32>();
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d)
+ sparse_indices_t(n, d) = (d == IndexDim) ? n : 0;
+ }
+
+ for (auto* ptr :
+ {&sparse_indices, &output_shape, &sparse_values, &default_value}) {
+ inputs.push_back({nullptr, ptr});
+ }
+
+ NodeDef sparse_node_def;
+ TF_CHECK_OK(NodeDefBuilder("sparsetodense", "SparseToDense")
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(&sparse_node_def));
+
+ Status status;
+ std::unique_ptr<OpKernel> op(CreateOpKernel(
+ DEVICE_CPU, device.get(), cpu_allocator(), sparse_node_def, &status));
+
+ OpKernelContext::Params params;
+ params.device = device.get();
+ params.frame_iter = FrameAndIter(0, 0);
+ params.inputs = &inputs;
+ params.op_kernel = op.get();
+ params.output_alloc_attr = [&device, &op, &params](int index) {
+ AllocatorAttributes attr;
+ const bool on_host = (op->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ return attr;
+ };
+
+ std::unique_ptr<OpKernelContext> sparse_context(new OpKernelContext(params));
+ op->Compute(sparse_context.get());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ delete sparse_context->release_output(0).tensor;
+ op->Compute(sparse_context.get());
+ ASSERT_OK(sparse_context->status());
+ }
+ tensorflow::testing::StopTiming();
+
+  // Bytes of input processed per iteration: N float values plus N * NDIM
+  // int32 indices, each 4 bytes wide.
+ int64 bytes_per_iter = static_cast<int64>((N + N * NDIM) * sizeof(float));
+
+ tensorflow::testing::BytesProcessed(bytes_per_iter * iters);
+}
+
+BENCHMARK(BM_SparseToDense)
+ ->Arg(BM_Arg(1, 10))
+ ->Arg(BM_Arg(1, 100))
+ ->Arg(BM_Arg(1, 1000))
+ ->Arg(BM_Arg(1, 10000))
+ ->Arg(BM_Arg(2, 10))
+ ->Arg(BM_Arg(2, 100))
+ ->Arg(BM_Arg(2, 1000))
+ ->Arg(BM_Arg(2, 10000))
+ ->Arg(BM_Arg(3, 10))
+ ->Arg(BM_Arg(3, 100))
+ ->Arg(BM_Arg(3, 1000))
+ ->Arg(BM_Arg(3, 10000))
+ ->Arg(BM_Arg(5, 10))
+ ->Arg(BM_Arg(5, 100))
+ ->Arg(BM_Arg(5, 1000))
+ ->Arg(BM_Arg(5, 10000));
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
new file mode 100644
index 0000000000..f4f9ada000
--- /dev/null
+++ b/tensorflow/core/kernels/split_op.cc
@@ -0,0 +1,146 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/split_op.h"
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SplitOp : public OpKernel {
+ public:
+ explicit SplitOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* context) override {
+ const int32 split_dim = context->input(0).flat<int32>()(0);
+ const int32 num_split = num_outputs();
+ const Tensor& input = context->input(1);
+ const TensorShape& input_shape = input.shape();
+
+ OP_REQUIRES(
+ context, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
+ input_shape.dims(), "), but got ", split_dim));
+
+ OP_REQUIRES(
+ context, num_split > 0,
+ errors::InvalidArgument(
+ "Number of ways to split should be > 0, but got ", num_split));
+
+ OP_REQUIRES(context, input_shape.dim_size(split_dim) % num_split == 0,
+ errors::InvalidArgument(
+ "Number of ways to split should evenly divide the split "
+ "dimension, but got split_dim ",
+ split_dim, " (size = ", input_shape.dim_size(split_dim),
+ ") ", "and num_split ", num_split));
+
+ // Special case 1: num_split == 1. Nothing to do.
+ if (num_split == 1) {
+ VLOG(1) << "Split identity";
+ context->set_output(0, context->input(1));
+ return;
+ }
+
+ // Special case 2: split along the 1st dimension. We can share the
+ // underlying buffer.
+ //
+    // Apply this optimization conservatively: only share the buffer when the
+    // input is aligned, so the resulting tensors are aligned as well. This is
+    // conservative because if the immediate consumers of the resulting
+    // tensors do not use Eigen for computation, it is perfectly fine to avoid
+    // the copy even without alignment.
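+    //
+    // For example, splitting a [6, 4] tensor into num_split = 3 along dim 0
+    // yields three [2, 4] outputs that are views into the input buffer, with
+    // no data copied.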
+ if ((split_dim == 0) && IsInnerDimsSizeAligned<T>(input_shape)) {
+ VLOG(1) << "Slice dim 0: " << input_shape.DebugString();
+ const int64 delta = input_shape.dim_size(0) / num_split;
+ for (int i = 0; i < num_split; ++i) {
+ context->set_output(i, input.Slice(i * delta, (i + 1) * delta));
+ }
+ return;
+ }
+
+ int32 prefix_dim_size = 1;
+ for (int i = 0; i < split_dim; ++i) {
+ prefix_dim_size *= input_shape.dim_size(i);
+ }
+
+ int32 split_dim_size = input_shape.dim_size(split_dim);
+
+ int32 suffix_dim_size = 1;
+ for (int i = split_dim + 1; i < input_shape.dims(); ++i) {
+ suffix_dim_size *= input_shape.dim_size(i);
+ }
+
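+    // View the input as a 3-D tensor {prefix, split dim, suffix}. For
+    // example, splitting a [2, 6, 5] input along split_dim = 1 uses the view
+    // {2, 6, 5}, and each output is a {2, 6 / num_split, 5} slice of it.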
+ auto input_reshaped =
+ input.shaped<T, 3>({prefix_dim_size, split_dim_size, suffix_dim_size});
+
+ const int32 split_dim_output_size = split_dim_size / num_split;
+ TensorShape output_shape(input_shape);
+ output_shape.set_dim(split_dim, split_dim_output_size);
+
+ Eigen::DSizes<ptrdiff_t, 3> indices{0, 0, 0};
+ Eigen::DSizes<ptrdiff_t, 3> sizes{prefix_dim_size, split_dim_output_size,
+ suffix_dim_size};
+
+ for (int i = 0; i < num_split; ++i) {
+ Tensor* result = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &result));
+ if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
+ Eigen::DSizes<ptrdiff_t, 3> slice_indices;
+ Eigen::DSizes<ptrdiff_t, 3> slice_sizes;
+ for (int j = 0; j < 3; ++j) {
+ slice_indices[j] = indices[j];
+ slice_sizes[j] = sizes[j];
+ }
+
+ auto result_shaped = result->shaped<T, 3>(
+ {prefix_dim_size, split_dim_output_size, suffix_dim_size});
+
+ functor::Split<Device, T>()(context->eigen_device<Device>(),
+ result_shaped, input_reshaped,
+ slice_indices, slice_sizes);
+ }
+ indices[1] += split_dim_output_size;
+ }
+ }
+};
+
+#define REGISTER_SPLIT(type) \
+ REGISTER_KERNEL_BUILDER(Name("Split") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("split_dim"), \
+ SplitOp<CPUDevice, type>)
+
+TF_CALL_ALL_TYPES(REGISTER_SPLIT);
+
+#undef REGISTER_SPLIT
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU(type) \
+ REGISTER_KERNEL_BUILDER(Name("Split") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("split_dim"), \
+ SplitOp<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
+
+#endif // GOOGLE_CUDA
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/split_op.h b/tensorflow/core/kernels/split_op.h
new file mode 100644
index 0000000000..2572c77285
--- /dev/null
+++ b/tensorflow/core/kernels/split_op.h
@@ -0,0 +1,31 @@
+#ifndef TENSORFLOW_KERNELS_SPLIT_OP_H_
+#define TENSORFLOW_KERNELS_SPLIT_OP_H_
+// Functor definition for SplitOp, must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct Split {
+ void operator()(const Device& d, typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+};
+
+template <typename T>
+struct Split<Eigen::ThreadPoolDevice, T> {
+ void operator()(const Eigen::ThreadPoolDevice& d,
+ typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_SPLIT_OP_H_
diff --git a/tensorflow/core/kernels/split_op_cpu.cc b/tensorflow/core/kernels/split_op_cpu.cc
new file mode 100644
index 0000000000..b86deeb8fb
--- /dev/null
+++ b/tensorflow/core/kernels/split_op_cpu.cc
@@ -0,0 +1,30 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/split_op.h"
+
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename T>
+void Split<Eigen::ThreadPoolDevice, T>::operator()(
+ const Eigen::ThreadPoolDevice& d, typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
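+  // Small slices are evaluated inline on the calling thread; dispatching to
+  // the thread-pool device only pays off once the output is large enough.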
+ if (output.size() < 131072) {
+ output = input.slice(slice_indices, slice_sizes);
+ } else {
+ output.device(d) = input.slice(slice_indices, slice_sizes);
+ }
+}
+
+#define DEFINE_CPU_KERNELS(T) template struct Split<Eigen::ThreadPoolDevice, T>;
+
+TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
+
+} // namespace functor
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc
new file mode 100644
index 0000000000..f8931d6a89
--- /dev/null
+++ b/tensorflow/core/kernels/split_op_gpu.cu.cc
@@ -0,0 +1,31 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+
+#include "tensorflow/core/kernels/split_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+void Split<Device, T>::operator()(
+ const Device& d, typename TTypes<T, 3>::Tensor output,
+ typename TTypes<T, 3>::ConstTensor input,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
+ const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+ output.device(d) = input.slice(slice_indices, slice_sizes);
+}
+
+#define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
new file mode 100644
index 0000000000..bd6fa47268
--- /dev/null
+++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc
@@ -0,0 +1,47 @@
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+class StringToHashBucketOp : public OpKernel {
+ public:
+ explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<int64>();
+
+ for (int i = 0; i < input_flat.size(); ++i) {
+ const uint64 input_hash = Hash64(input_flat(i));
+ const uint64 bucket_id = input_hash % num_buckets_;
+      // num_buckets_ is always within the positive range of int64, so the
+      // resulting bucket_id is as well; casting bucket_id from uint64 to
+      // int64 is therefore safe.
+ output_flat(i) = static_cast<int64>(bucket_id);
+ }
+ }
+
+ private:
+ int64 num_buckets_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(StringToHashBucketOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringToHashBucket").Device(DEVICE_CPU),
+ StringToHashBucketOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc
new file mode 100644
index 0000000000..8d23a4fdf8
--- /dev/null
+++ b/tensorflow/core/kernels/string_to_number_op.cc
@@ -0,0 +1,71 @@
+// See docs in ../ops/parse_ops.cc.
+
+#include <errno.h>
+#include <string>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+static constexpr char kErrorMessage[] =
+ "StringToNumberOp could not correctly convert string: ";
+
+template <typename OutputType>
+class StringToNumberOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* context) override {
+    // Fetching the input tensor does not copy it; the pointer below refers
+    // to the same underlying storage.
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("string_tensor", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<OutputType>();
+
+ for (int i = 0; i < input_flat.size(); ++i) {
+ const char* s = input_flat(i).data();
+ Convert(s, &output_flat(i), context);
+ }
+ }
+
+ private:
+ void Convert(const char* s, OutputType* output_data,
+ OpKernelContext* context);
+};
+
+template <>
+void StringToNumberOp<float>::Convert(const char* s, float* output_data,
+ OpKernelContext* context) {
+ OP_REQUIRES(context, strings::safe_strtof(s, output_data),
+ errors::InvalidArgument(kErrorMessage, s));
+}
+
+template <>
+void StringToNumberOp<int32>::Convert(const char* s, int32* output_data,
+ OpKernelContext* context) {
+ OP_REQUIRES(context, strings::safe_strto32(s, output_data),
+ errors::InvalidArgument(kErrorMessage, s));
+}
+
+// Registers the currently supported output types.
+#define REGISTER(type) \
+ REGISTER_KERNEL_BUILDER(Name("StringToNumber") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("out_type"), \
+ StringToNumberOp<type>)
+REGISTER(float);
+REGISTER(int32);
+#undef REGISTER
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc
new file mode 100644
index 0000000000..ba765f2e84
--- /dev/null
+++ b/tensorflow/core/kernels/summary_image_op.cc
@@ -0,0 +1,169 @@
+// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
+// inputs or outputs in various ways.
+
+// See docs in ../ops/summary_ops.cc.
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/png/png_io.h"
+
+namespace tensorflow {
+
+class SummaryImageOp : public OpKernel {
+ public:
+ explicit SummaryImageOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("max_images", &max_images_));
+ const TensorProto* proto;
+ OP_REQUIRES_OK(context, context->GetAttr("bad_color", &proto));
+ OP_REQUIRES_OK(context, context->device()->MakeTensorFromProto(
+ *proto, AllocatorAttributes(), &bad_color_));
+ OP_REQUIRES(context, bad_color_.dtype() == DT_UINT8,
+ errors::InvalidArgument("bad_color must be uint8, got ",
+ DataTypeString(bad_color_.dtype())));
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(bad_color_.shape()),
+ errors::InvalidArgument("bad_color must be a vector, got shape ",
+ bad_color_.shape().ShortDebugString()));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& tags = c->input(0);
+ const Tensor& tensor = c->input(1);
+    OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
+                errors::InvalidArgument("Tags must be a scalar"));
+ OP_REQUIRES(c, tensor.dims() == 4 &&
+ (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
+ tensor.dim_size(3) == 4),
+ errors::InvalidArgument(
+ "Tensor must be 4-D with last dim 1, 3, or 4, not ",
+ tensor.shape().DebugString()));
+ const string& base_tag = tags.scalar<string>()();
+
+ const int batch_size = tensor.dim_size(0);
+ const int h = tensor.dim_size(1);
+ const int w = tensor.dim_size(2);
+ const int hw = h * w; // Compact these two dims for simplicity
+ const int depth = tensor.dim_size(3);
+ auto tensor_eigen = tensor.shaped<float, 3>({batch_size, hw, depth});
+
+ OP_REQUIRES(c, bad_color_.dim_size(0) >= depth,
+ errors::InvalidArgument(
+ "expected depth <= bad_color.size, got depth = ", depth,
+ ", bad_color.size = ", bad_color_.dim_size(0)));
+ auto bad_color_full = bad_color_.vec<uint8>();
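+    // View only the first `depth` entries of bad_color; any extra entries
+    // are ignored.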
+ typename TTypes<uint8>::Vec bad_color(bad_color_full.data(), depth);
+
+ // RGB (or gray or RGBA) is last dimension
+ Eigen::Tensor<uint8, 2, Eigen::RowMajor> image(hw, depth);
+
+ Summary s;
+ const int N = std::min<int>(max_images_, batch_size);
+ for (int i = 0; i < N; ++i) {
+ Summary::Value* v = s.add_value();
+      // The tag depends on the number of requested images (not the number
+      // produced).
+ //
+ // Note that later on avisu uses "/" to figure out a consistent naming
+ // convention for display, so we append "/image" to guarantee that the
+ // image(s) won't be displayed in the global scope with no name.
+ if (max_images_ > 1) {
+ v->set_tag(strings::StrCat(base_tag, "/image/", i));
+ } else {
+ v->set_tag(strings::StrCat(base_tag, "/image"));
+ }
+
+ if (image.size()) {
+ typename TTypes<float>::ConstMatrix values(
+ &tensor_eigen(i, 0, 0),
+ Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
+
+ // Rescale the image to uint8 range.
+ //
+        // We are trying to generate an RGB image from a float tensor. We do
+ // not have any info about the expected range of values in the tensor
+ // but the generated image needs to have all RGB values within [0, 255].
+ //
+ // We use two different algorithms to generate these values. If the
+ // tensor has only positive values we scale them all by 255/max(values).
+ // If the tensor has both negative and positive values we scale them by
+ // the max of their absolute values and center them around 127.
+ //
+        // This works for most cases, but has the inconvenience of not
+        // respecting the relative dynamic range across different instances
+        // of the tensor.
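+        //
+        // For example, values in [0, 2] are scaled by 255 / 2 = 127.5 into
+        // [0, 255]; values in [-1, 3] are scaled by 127 / 3 and shifted by
+        // 128, landing in roughly [86, 255].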
+
+ // Compute min and max ignoring nonfinite pixels
+ float image_min = std::numeric_limits<float>::infinity();
+ float image_max = -image_min;
+ for (int i = 0; i < hw; i++) {
+ bool finite = true;
+ for (int j = 0; j < depth; j++) {
+ if (!std::isfinite(values(i, j))) {
+ finite = false;
+ break;
+ }
+ }
+ if (finite) {
+ for (int j = 0; j < depth; j++) {
+ float value = values(i, j);
+ image_min = std::min(image_min, value);
+ image_max = std::max(image_max, value);
+ }
+ }
+ }
+
+ // Pick an affine transform into uint8
+ const float kZeroThreshold = 1e-6;
+ float scale, offset;
+ if (image_min < 0) {
+ float max_val = std::max(std::abs(image_min), std::abs(image_max));
+ scale = max_val < kZeroThreshold ? 0.0f : 127.0f / max_val;
+ offset = 128.0f;
+ } else {
+ scale = image_max < kZeroThreshold ? 0.0f : 255.0f / image_max;
+ offset = 0.0f;
+ }
+
+ // Transform image, turning nonfinite values to bad_color
+ for (int i = 0; i < hw; i++) {
+ bool finite = true;
+ for (int j = 0; j < depth; j++) {
+ if (!std::isfinite(values(i, j))) {
+ finite = false;
+ break;
+ }
+ }
+ if (finite) {
+ image.chip<0>(i) =
+ (values.chip<0>(i) * scale + offset).cast<uint8>();
+ } else {
+ image.chip<0>(i) = bad_color;
+ }
+ }
+ }
+
+ Summary::Image* si = v->mutable_image();
+ si->set_height(h);
+ si->set_width(w);
+ si->set_colorspace(depth);
+ OP_REQUIRES(c, png::WriteImageToBuffer(
+ image.data(), w, h, w * depth, depth, 8, -1,
+ si->mutable_encoded_image_string(), nullptr),
+ errors::Internal("PNG encoding failed"));
+ }
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+
+ private:
+ int64 max_images_;
+ Tensor bad_color_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ImageSummary").Device(DEVICE_CPU),
+ SummaryImageOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/summary_image_op_test.cc b/tensorflow/core/kernels/summary_image_op_test.cc
new file mode 100644
index 0000000000..ddfeeffc0b
--- /dev/null
+++ b/tensorflow/core/kernels/summary_image_op_test.cc
@@ -0,0 +1,141 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+static void EXPECT_SummaryMatches(const Summary& actual,
+ const string& expected_str) {
+ Summary expected;
+ CHECK(protobuf::TextFormat::ParseFromString(expected_str, &expected));
+ EXPECT_EQ(expected.DebugString(), actual.DebugString());
+}
+
+// --------------------------------------------------------------------------
+// SummaryImageOp
+// --------------------------------------------------------------------------
+class SummaryImageOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(int max_images) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "ImageSummary")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Attr("max_images", max_images)
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+
+ void CheckAndRemoveEncodedImages(Summary* summary) {
+ for (int i = 0; i < summary->value_size(); ++i) {
+ Summary::Value* value = summary->mutable_value(i);
+ ASSERT_TRUE(value->has_image()) << "No image for value: " << value->tag();
+ ASSERT_FALSE(value->image().encoded_image_string().empty())
+ << "No encoded_image_string for value: " << value->tag();
+ if (VLOG_IS_ON(2)) {
+        // When verbose logging is on, write the images to disk for manual
+        // inspection.
+ TF_CHECK_OK(WriteStringToFile(
+ Env::Default(), strings::StrCat("/tmp/", value->tag(), ".png"),
+ value->image().encoded_image_string()));
+ }
+ value->mutable_image()->clear_encoded_image_string();
+ }
+ }
+};
+
+TEST_F(SummaryImageOpTest, ThreeGrayImagesOutOfFive4dInput) {
+ MakeOp(3 /* max images */);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({}), {"tag"});
+ AddInputFromArray<float>(TensorShape({5, 2, 1, 1}),
+ {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+
+ CheckAndRemoveEncodedImages(&summary);
+ EXPECT_SummaryMatches(summary, R"(
+ value { tag: 'tag/image/0' image { width: 1 height: 2 colorspace: 1} }
+ value { tag: 'tag/image/1' image { width: 1 height: 2 colorspace: 1} }
+ value { tag: 'tag/image/2' image { width: 1 height: 2 colorspace: 1} }
+ )");
+}
+
+TEST_F(SummaryImageOpTest, OneGrayImage4dInput) {
+ MakeOp(1 /* max images */);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({}), {"tag"});
+ AddInputFromArray<float>(TensorShape({5 /*batch*/, 2, 1, 1 /*depth*/}),
+ {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+
+ CheckAndRemoveEncodedImages(&summary);
+ EXPECT_SummaryMatches(summary, R"(
+ value { tag: 'tag/image' image { width: 1 height: 2 colorspace: 1} })");
+}
+
+TEST_F(SummaryImageOpTest, OneColorImage4dInput) {
+ MakeOp(1 /* max images */);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({}), {"tag"});
+ AddInputFromArray<float>(
+ TensorShape({1 /*batch*/, 5 /*rows*/, 2 /*columns*/, 3 /*depth*/}),
+ {
+ /* r0, c0, RGB */ 1.0, 0.1, 0.2,
+ /* r0, c1, RGB */ 1.0, 0.3, 0.4,
+ /* r1, c0, RGB */ 0.0, 1.0, 0.0,
+ /* r1, c1, RGB */ 0.0, 1.0, 0.0,
+ /* r2, c0, RGB */ 0.0, 0.0, 1.0,
+ /* r2, c1, RGB */ 0.0, 0.0, 1.0,
+ /* r3, c0, RGB */ 1.0, 1.0, 0.0,
+ /* r3, c1, RGB */ 1.0, 0.0, 1.0,
+ /* r4, c0, RGB */ 1.0, 1.0, 0.0,
+ /* r4, c1, RGB */ 1.0, 0.0, 1.0,
+ });
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+
+ CheckAndRemoveEncodedImages(&summary);
+ EXPECT_SummaryMatches(summary, R"(
+ value { tag: 'tag/image' image { width: 2 height: 5 colorspace: 3} })");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
new file mode 100644
index 0000000000..1c4be64b8b
--- /dev/null
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -0,0 +1,141 @@
+// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
+// inputs or outputs in various ways.
+
+// See docs in ../ops/summary_ops.cc.
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SummaryScalarOp : public OpKernel {
+ public:
+ explicit SummaryScalarOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& tags = c->input(0);
+ const Tensor& values = c->input(1);
+
+ OP_REQUIRES(c, tags.IsSameSize(values) ||
+ (TensorShapeUtils::IsLegacyScalar(tags.shape()) &&
+ TensorShapeUtils::IsLegacyScalar(values.shape())),
+ errors::InvalidArgument("tags and values not the same shape: ",
+ tags.shape().ShortDebugString(), " != ",
+ values.shape().ShortDebugString()));
+ auto Ttags = tags.flat<string>();
+ auto Tvalues = values.flat<T>();
+ Summary s;
+ for (int i = 0; i < Ttags.size(); i++) {
+ Summary::Value* v = s.add_value();
+ v->set_tag(Ttags(i));
+ v->set_simple_value(Tvalues(i));
+ }
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ SummaryScalarOp<float>);
+REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ SummaryScalarOp<double>);
+
+class SummaryHistoOp : public OpKernel {
+ public:
+ // SummaryHistoOp could be extended to take a list of custom bucket
+ // boundaries as an option.
+ explicit SummaryHistoOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ const Tensor& tags = c->input(0);
+ const Tensor& values = c->input(1);
+ const auto flat = values.flat<float>();
+ OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
+ errors::InvalidArgument("tags must be scalar"));
+ // Build histogram of values in "values" tensor
+ histogram::Histogram histo;
+ for (int64 i = 0; i < flat.size(); i++) {
+ float v = flat(i);
+ if (!std::isfinite(v)) {
+ c->SetStatus(
+ errors::OutOfRange("Nan in summary histogram for: ", name()));
+ break;
+ }
+ histo.Add(v);
+ }
+
+ Summary s;
+ Summary::Value* v = s.add_value();
+ v->set_tag(tags.scalar<string>()());
+ histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("HistogramSummary").Device(DEVICE_CPU),
+ SummaryHistoOp);
+
+struct HistogramResource : public ResourceBase {
+ histogram::ThreadSafeHistogram histogram;
+
+  string DebugString() override { return "A histogram summary. Stats ..."; }
+};
+
+class SummaryMergeOp : public OpKernel {
+ public:
+ explicit SummaryMergeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* c) override {
+ Summary s;
+ std::unordered_set<string> tags;
+ for (int input_num = 0; input_num < c->num_inputs(); input_num++) {
+ const Tensor& in = c->input(input_num);
+ auto in_vec = in.flat<string>();
+ for (int i = 0; i < in_vec.dimension(0); i++) {
+ const string& s_in = in_vec(i);
+ Summary summary_in;
+ if (!ParseProtoUnlimited(&summary_in, s_in)) {
+ c->SetStatus(errors::InvalidArgument(
+ "Could not parse one of the summary inputs"));
+ return;
+ }
+
+ for (int v = 0; v < summary_in.value_size(); v++) {
+ if (!tags.insert(summary_in.value(v).tag()).second) {
+ c->SetStatus(errors::InvalidArgument(
+ strings::StrCat("Duplicate tag ", summary_in.value(v).tag(),
+ " found in summary inputs")));
+ return;
+ }
+ *s.add_value() = summary_in.value(v);
+ }
+ }
+ }
+
+ Tensor* summary_tensor = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
+ CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MergeSummary").Device(DEVICE_CPU),
+ SummaryMergeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/summary_op_test.cc b/tensorflow/core/kernels/summary_op_test.cc
new file mode 100644
index 0000000000..fd271a6862
--- /dev/null
+++ b/tensorflow/core/kernels/summary_op_test.cc
@@ -0,0 +1,282 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+static void EXPECT_SummaryMatches(const Summary& actual,
+ const string& expected_str) {
+ Summary expected;
+ CHECK(protobuf::TextFormat::ParseFromString(expected_str, &expected));
+ EXPECT_EQ(expected.DebugString(), actual.DebugString());
+}
+
+class SummaryScalarOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(DataType dt) {
+ RequireDefaultOps();
+ ASSERT_OK(NodeDefBuilder("myop", "ScalarSummary")
+ .Input(FakeInput())
+ .Input(FakeInput(dt))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SummaryScalarOpTest, SimpleFloat) {
+ MakeOp(DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({3}), {"tag1", "tag2", "tag3"});
+ AddInputFromArray<float>(TensorShape({3}), {1.0, -0.73, 10000.0});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+ EXPECT_SummaryMatches(summary, R"(
+ value { tag: 'tag1' simple_value: 1.0 }
+ value { tag: 'tag2' simple_value: -0.73 }
+ value { tag: 'tag3' simple_value: 10000.0 }
+ )");
+}
+
+TEST_F(SummaryScalarOpTest, SimpleDouble) {
+ MakeOp(DT_DOUBLE);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({3}), {"tag1", "tag2", "tag3"});
+ AddInputFromArray<double>(TensorShape({3}), {1.0, -0.73, 10000.0});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+ EXPECT_SummaryMatches(summary, R"(
+ value { tag: 'tag1' simple_value: 1.0 }
+ value { tag: 'tag2' simple_value: -0.73 }
+ value { tag: 'tag3' simple_value: 10000.0 }
+ )");
+}
+
+TEST_F(SummaryScalarOpTest, Error_MismatchedSize) {
+ MakeOp(DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({2}), {"tag1", "tag2"});
+ AddInputFromArray<float>(TensorShape({3}), {1.0, -0.73, 10000.0});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("not the same shape")) << s;
+}
+
+TEST_F(SummaryScalarOpTest, Error_WrongDimsTags) {
+ MakeOp(DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({2, 1}), {"tag1", "tag2"});
+ AddInputFromArray<float>(TensorShape({2}), {1.0, -0.73});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(
+ StringPiece(s.ToString()).contains("tags and values not the same shape"))
+ << s;
+}
+
+TEST_F(SummaryScalarOpTest, Error_WrongDimsValues) {
+ MakeOp(DT_FLOAT);
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({2}), {"tag1", "tag2"});
+ AddInputFromArray<float>(TensorShape({2, 1}), {1.0, -0.73});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(
+ StringPiece(s.ToString()).contains("tags and values not the same shape"))
+ << s;
+}
+
+// --------------------------------------------------------------------------
+// SummaryHistoOp
+// --------------------------------------------------------------------------
+class SummaryHistoOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ ASSERT_OK(NodeDefBuilder("myop", "HistogramSummary")
+ .Input(FakeInput())
+ .Input(FakeInput())
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SummaryHistoOpTest, Simple) {
+ MakeOp();
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({}), {"taghisto"});
+ AddInputFromArray<float>(TensorShape({3, 2}), {0.1, -0.7, 4.1, 4., 5., 4.});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+ ASSERT_EQ(summary.value_size(), 1);
+ EXPECT_EQ(summary.value(0).tag(), "taghisto");
+ histogram::Histogram histo;
+ EXPECT_TRUE(histo.DecodeFromProto(summary.value(0).histo()));
+ EXPECT_EQ(
+ "Count: 6 Average: 2.7500 StdDev: 2.20\n"
+ "Min: -0.7000 Median: 3.9593 Max: 5.0000\n"
+ "------------------------------------------------------\n"
+ "[ -0.76, -0.69 ) 1 16.667% 16.667% ###\n"
+ "[ 0.093, 0.1 ) 1 16.667% 33.333% ###\n"
+ "[ 3.8, 4.2 ) 3 50.000% 83.333% ##########\n"
+ "[ 4.6, 5.1 ) 1 16.667% 100.000% ###\n",
+ histo.ToString());
+}
+
+TEST_F(SummaryHistoOpTest, Error_WrongDimsTags) {
+ MakeOp();
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({2, 1}), {"tag1", "tag2"});
+ AddInputFromArray<float>(TensorShape({2}), {1.0, -0.73});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("tags must be scalar")) << s;
+}
+
+TEST_F(SummaryHistoOpTest, Error_TooManyTagValues) {
+ MakeOp();
+
+ // Feed and run
+ AddInputFromArray<string>(TensorShape({2}), {"tag1", "tag2"});
+ AddInputFromArray<float>(TensorShape({2, 1}), {1.0, -0.73});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("tags must be scalar")) << s;
+}
+
+// --------------------------------------------------------------------------
+// SummaryMergeOp
+// --------------------------------------------------------------------------
+class SummaryMergeOpTest : public OpsTestBase {
+ protected:
+ void MakeOp(int num_inputs) {
+ ASSERT_OK(NodeDefBuilder("myop", "MergeSummary")
+ .Input(FakeInput(num_inputs))
+ .Finalize(node_def()));
+ ASSERT_OK(InitOp());
+ }
+};
+
+TEST_F(SummaryMergeOpTest, Simple) {
+ MakeOp(1);
+
+ // Feed and run
+ Summary s1;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag1\" simple_value: 1.0 } "
+ "value { tag: \"tag2\" simple_value: -0.73 } ",
+ &s1));
+ Summary s2;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag3\" simple_value: 10000.0 }", &s2));
+ Summary s3;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag4\" simple_value: 11.0 }", &s3));
+
+ AddInputFromArray<string>(
+ TensorShape({3}),
+ {s1.SerializeAsString(), s2.SerializeAsString(), s3.SerializeAsString()});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+
+ EXPECT_SummaryMatches(summary,
+ "value { tag: \"tag1\" simple_value: 1.0 } "
+ "value { tag: \"tag2\" simple_value: -0.73 } "
+ "value { tag: \"tag3\" simple_value: 10000.0 }"
+ "value { tag: \"tag4\" simple_value: 11.0 }");
+}
+
+TEST_F(SummaryMergeOpTest, Simple_MultipleInputs) {
+ MakeOp(3);
+
+ // Feed and run
+ Summary s1;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag1\" simple_value: 1.0 } "
+ "value { tag: \"tag2\" simple_value: -0.73 } ",
+ &s1));
+ Summary s2;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag3\" simple_value: 10000.0 }", &s2));
+ Summary s3;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag4\" simple_value: 11.0 }", &s3));
+
+ AddInputFromArray<string>(TensorShape({}), {s1.SerializeAsString()});
+ AddInputFromArray<string>(TensorShape({}), {s2.SerializeAsString()});
+ AddInputFromArray<string>(TensorShape({}), {s3.SerializeAsString()});
+ ASSERT_OK(RunOpKernel());
+
+ // Check the output size.
+ Tensor* out_tensor = GetOutput(0);
+ ASSERT_EQ(0, out_tensor->dims());
+ Summary summary;
+ ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
+
+ EXPECT_SummaryMatches(summary,
+ "value { tag: \"tag1\" simple_value: 1.0 } "
+ "value { tag: \"tag2\" simple_value: -0.73 } "
+ "value { tag: \"tag3\" simple_value: 10000.0 }"
+ "value { tag: \"tag4\" simple_value: 11.0 }");
+}
+
+TEST_F(SummaryMergeOpTest, Error_DuplicateTag) {
+ MakeOp(1);
+
+ // Feed and run
+ Summary s1;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tag1\" simple_value: 1.0 } "
+ "value { tag: \"tagduplicate\" simple_value: -0.73 } ",
+ &s1));
+ Summary s2;
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "value { tag: \"tagduplicate\" simple_value: 1.0 } ", &s2));
+ AddInputFromArray<string>(TensorShape({2}),
+ {s1.SerializeAsString(), s2.SerializeAsString()});
+ Status s = RunOpKernel();
+ EXPECT_TRUE(StringPiece(s.ToString()).contains("Duplicate tag")) << s;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/text_line_reader_op.cc b/tensorflow/core/kernels/text_line_reader_op.cc
new file mode 100644
index 0000000000..51e4d6a2b8
--- /dev/null
+++ b/tensorflow/core/kernels/text_line_reader_op.cc
@@ -0,0 +1,99 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <memory>
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+class TextLineReader : public ReaderBase {
+ public:
+ TextLineReader(const string& node_name, int skip_header_lines, Env* env)
+ : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")),
+ skip_header_lines_(skip_header_lines),
+ env_(env),
+ line_number_(0) {}
+
+ Status OnWorkStartedLocked() override {
+ line_number_ = 0;
+ RandomAccessFile* file = nullptr;
+ TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
+ input_buffer_.reset(new io::InputBuffer(file, kBufferSize));
+ for (; line_number_ < skip_header_lines_; ++line_number_) {
+ string line_contents;
+ Status status = input_buffer_->ReadLine(&line_contents);
+ if (errors::IsOutOfRange(status)) {
+ // We ignore an end of file error when skipping header lines.
+ // We will end up skipping this file.
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(status);
+ }
+ return Status::OK();
+ }
+
+ Status OnWorkFinishedLocked() override {
+ input_buffer_.reset(nullptr);
+ return Status::OK();
+ }
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ Status status = input_buffer_->ReadLine(value);
+ ++line_number_;
+ if (status.ok()) {
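+      // Keys have the form "<filename>:<line number>", where line numbers
+      // are 1-based and count any skipped header lines.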
+ *key = strings::StrCat(current_work(), ":", line_number_);
+ *produced = true;
+ return status;
+ }
+ if (errors::IsOutOfRange(status)) { // End of file, advance to the next.
+ *at_end = true;
+ return Status::OK();
+ } else { // Some other reading error
+ return status;
+ }
+ }
+
+ Status ResetLocked() override {
+ line_number_ = 0;
+ input_buffer_.reset(nullptr);
+ return ReaderBase::ResetLocked();
+ }
+
+ // TODO(josh11b): Implement serializing and restoring the state. Need
+ // to create TextLineReaderState proto to store ReaderBaseState,
+ // line_number_, and input_buffer_->Tell().
+
+ private:
+ enum { kBufferSize = 256 << 10 /* 256 kB */ };
+ const int skip_header_lines_;
+ Env* const env_;
+ int64 line_number_;
+ std::unique_ptr<io::InputBuffer> input_buffer_;
+};
+
+class TextLineReaderOp : public ReaderOpKernel {
+ public:
+ explicit TextLineReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ int skip_header_lines = -1;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("skip_header_lines", &skip_header_lines));
+ OP_REQUIRES(context, skip_header_lines >= 0,
+ errors::InvalidArgument("skip_header_lines must be >= 0 not ",
+ skip_header_lines));
+ Env* env = context->env();
+ SetReaderFactory([this, skip_header_lines, env]() {
+ return new TextLineReader(name(), skip_header_lines, env);
+ });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
+ TextLineReaderOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc
new file mode 100644
index 0000000000..551be18d5f
--- /dev/null
+++ b/tensorflow/core/kernels/tf_record_reader_op.cc
@@ -0,0 +1,76 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <memory>
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+class TFRecordReader : public ReaderBase {
+ public:
+ TFRecordReader(const string& node_name, Env* env)
+ : ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")),
+ env_(env),
+ offset_(0) {}
+
+ Status OnWorkStartedLocked() override {
+ offset_ = 0;
+ RandomAccessFile* file = nullptr;
+ TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
+ file_.reset(file);
+ reader_.reset(new io::RecordReader(file));
+ return Status::OK();
+ }
+
+ Status OnWorkFinishedLocked() override {
+ reader_.reset(nullptr);
+ file_.reset(nullptr);
+ return Status::OK();
+ }
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ *key = strings::StrCat(current_work(), ":", offset_);
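+    // offset_ still points at the start of the record about to be read, so
+    // the key is "<filename>:<byte offset of the record>".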
+ Status status = reader_->ReadRecord(&offset_, value);
+ if (errors::IsOutOfRange(status)) {
+ *at_end = true;
+ return Status::OK();
+ }
+ if (!status.ok()) return status;
+ *produced = true;
+ return Status::OK();
+ }
+
+ Status ResetLocked() override {
+ offset_ = 0;
+ reader_.reset(nullptr);
+ file_.reset(nullptr);
+ return ReaderBase::ResetLocked();
+ }
+
+ // TODO(josh11b): Implement serializing and restoring the state.
+
+ private:
+ Env* const env_;
+ uint64 offset_;
+ std::unique_ptr<RandomAccessFile> file_;
+ std::unique_ptr<io::RecordReader> reader_;
+};
+
+class TFRecordReaderOp : public ReaderOpKernel {
+ public:
+ explicit TFRecordReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ Env* env = context->env();
+ SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
+ TFRecordReaderOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
new file mode 100644
index 0000000000..d5e0e89d60
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -0,0 +1,460 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#ifdef GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif // GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/tile_ops.h"
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// --------------------------------------------------------------------------
+template <typename Device>
+class TileOp : public OpKernel {
+ public:
+ explicit TileOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& multiples = context->input(1);
+
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+ errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
+ multiples.shape().ShortDebugString()));
+ OP_REQUIRES(context, input.dims() == multiples.NumElements(),
+ errors::InvalidArgument(
+ "Expected multiples argument to be a vector of length ",
+ input.dims(), " but got length ", multiples.dim_size(0)));
+
+ const int input_dims = input.dims();
+ const gtl::ArraySlice<int32> multiples_array(multiples.flat<int32>().data(),
+ input_dims);
+
+ TensorShape output_shape;
+ for (int i = 0; i < input_dims; ++i) {
+ OP_REQUIRES(
+ context, multiples_array[i] > 0,
+ errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ",
+ multiples_array[i]));
+ output_shape.AddDim(input.dim_size(i) * multiples_array[i]);
+ }
+ Tensor* result = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
+
+#define HANDLE_DIM(DT, NDIM) \
+ if (context->input(0).dtype() == DT && input_dims == NDIM) { \
+ HandleCase<DT, NDIM>(context, multiples_array, result); \
+ return; \
+ }
+
+#define HANDLE_TYPE(T) \
+ HANDLE_DIM(T, 0) \
+ HANDLE_DIM(T, 1) \
+ HANDLE_DIM(T, 2) \
+ HANDLE_DIM(T, 3) \
+ HANDLE_DIM(T, 4) \
+ HANDLE_DIM(T, 5)
+
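+  // Each HANDLE_TYPE(T) expands into a chain of dtype/rank checks; the first
+  // match dispatches to HandleCase<T, NDIM> and returns, so unsupported
+  // combinations fall through to the Unimplemented error below.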
+ HANDLE_TYPE(DT_BOOL);
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT64);
+ HANDLE_TYPE(DT_STRING); // when DEVICE=CPUDevice.
+
+#undef HANDLE_TYPE
+#undef HANDLE_DIM
+
+ OP_REQUIRES(context, false,
+ errors::Unimplemented(
+ "TileOp : Unhandled input dimensions, DT : ",
+ context->input(0).dtype(), ", dims : ", input_dims));
+ }
+
+ private:
+ template <DataType DT, int NDIM>
+ void HandleCaseImpl(OpKernelContext* context,
+ const gtl::ArraySlice<int32>& multiples_array,
+ Tensor* result) {
+ typedef typename EnumToDataType<DT>::Type T;
+ Eigen::array<int32, NDIM> broadcast_array;
+ for (int i = 0; i < NDIM; ++i) {
+ broadcast_array[i] = multiples_array[i];
+ }
+ functor::Tile<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), broadcast_array);
+ }
+
+ template <DataType DT, int NDIM>
+ void HandleCase(OpKernelContext* context,
+ const gtl::ArraySlice<int32>& multiples_array,
+ Tensor* result);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
+};
+
+template <typename Device>
+template <DataType DT, int NDIM>
+inline void TileOp<Device>::HandleCase(
+ OpKernelContext* context, const gtl::ArraySlice<int32>& multiples_array,
+ Tensor* result) {
+ LOG(FATAL) << "TileOp: Invalid combination of Device, DT and NDIM: "
+ << typeid(Device).name() << ", " << DataTypeString(DT) << ", "
+ << NDIM;
+}
+
+#define HANDLE_CASE(device, dtype, ndim) \
+ template <> \
+ template <> \
+ void TileOp<device>::HandleCase<dtype, ndim>( \
+ OpKernelContext * context, \
+ const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { \
+ HandleCaseImpl<dtype, ndim>(context, multiples_array, result); \
+ }
+
+#define HANDLE_CASE_DIM_POSITIVE(device, dtype) \
+ HANDLE_CASE(device, dtype, 1); \
+ HANDLE_CASE(device, dtype, 2); \
+ HANDLE_CASE(device, dtype, 3); \
+ HANDLE_CASE(device, dtype, 4); \
+ HANDLE_CASE(device, dtype, 5);
+
+#define HANDLE_CASE_DIM(device, dtype) \
+ HANDLE_CASE(device, dtype, 0); \
+ HANDLE_CASE_DIM_POSITIVE(device, dtype);
+
+HANDLE_CASE_DIM(CPUDevice, DT_BOOL);
+HANDLE_CASE_DIM(CPUDevice, DT_FLOAT);
+HANDLE_CASE_DIM(CPUDevice, DT_DOUBLE);
+HANDLE_CASE_DIM(CPUDevice, DT_UINT8);
+HANDLE_CASE_DIM(CPUDevice, DT_INT32);
+HANDLE_CASE_DIM(CPUDevice, DT_INT16);
+HANDLE_CASE_DIM(CPUDevice, DT_INT64);
+HANDLE_CASE_DIM(CPUDevice, DT_STRING);
+
+#if GOOGLE_CUDA
+// Eigen on GPU does not handle 0-dimension data types yet.
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_FLOAT);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_DOUBLE);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT16);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT32);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT64);
+#endif // GOOGLE_CUDA
+
+#undef HANDLE_CASE_DIM_POSITIVE
+#undef HANDLE_CASE_DIM
+#undef HANDLE_CASE
+
+// --------------------------------------------------------------------------
+template <typename Device>
+class TileGradientOp : public OpKernel {
+ public:
+ explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const Tensor& multiples = context->input(1);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+ errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
+ multiples.shape().ShortDebugString()));
+ OP_REQUIRES(context, input.dims() == multiples.NumElements(),
+ errors::InvalidArgument(
+ "Expected multiples argument to be a vector of length ",
+ input.dims(), " but got length ", multiples.dim_size(0)));
+
+ const int input_dims = input.dims();
+ const gtl::ArraySlice<int32> multiples_array(multiples.flat<int32>().data(),
+ input_dims);
+
+ TensorShape output_shape;
+ std::vector<int32> input_dim_size_vec;
+ for (int i = 0; i < input_dims; ++i) {
+ OP_REQUIRES(
+ context, multiples_array[i] > 0,
+ errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ",
+ multiples_array[i]));
+ OP_REQUIRES(context, input.dim_size(i) % multiples_array[i] == 0,
+ errors::InvalidArgument("Expected input_dim[", i,
+ "] to be divisible by multiples[", i,
+ "], but ", input.dim_size(i), " % ",
+ multiples_array[i], " != 0"));
+ output_shape.AddDim(input.dim_size(i) / multiples_array[i]);
+ input_dim_size_vec.push_back(input.dim_size(i));
+ }
+ Tensor* result = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
+
+#define HANDLE_DIM(DT, NDIM) \
+ if (context->input(0).dtype() == DT && input_dims == NDIM) { \
+ HandleCase<DT, NDIM>(context, input_dim_size_vec, multiples_array, \
+ result); \
+ return; \
+ }
+
+#define HANDLE_TYPE(T) \
+ HANDLE_DIM(T, 0) \
+ HANDLE_DIM(T, 1) \
+ HANDLE_DIM(T, 2) \
+ HANDLE_DIM(T, 3) \
+ HANDLE_DIM(T, 4) \
+ HANDLE_DIM(T, 5)
+
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT64);
+
+#undef HANDLE_TYPE
+#undef HANDLE_DIM
+
+ OP_REQUIRES(context, false,
+ errors::Unimplemented(
+ "TileGradientOp : Unhandled input dimensions, DT : ",
+ context->input(0).dtype(), ", dims : ", input_dims));
+ }
+
+ private:
+ template <DataType DT, int NDIM>
+ void HandleCase(OpKernelContext* context,
+ const std::vector<int32>& input_dims,
+ const gtl::ArraySlice<int32>& multiples_array,
+ Tensor* result);
+
+ template <DataType DT, int NDIM>
+ void HandleCaseImpl(OpKernelContext* context,
+ const std::vector<int32>& input_dims,
+ const gtl::ArraySlice<int32>& multiples_array,
+ Tensor* result) {
+ typedef typename EnumToDataType<DT>::Type T;
+
+ bool reduction_only = true;
+ std::vector<int> reduction_dims;
+
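+    // The gradient reduces to a plain sum when every tiled dimension has an
+    // output size of 1 (i.e. multiples[i] == input_dims[i]); those dimensions
+    // can be summed away directly instead of accumulating strided slices.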
+ for (int i = 0; i < NDIM; ++i) {
+ if (input_dims[i] > multiples_array[i] && multiples_array[i] > 1) {
+ reduction_only = false;
+ break;
+ } else {
+ if (multiples_array[i] == input_dims[i]) {
+ reduction_dims.push_back(i);
+ }
+ }
+ }
+
+ if (reduction_only) {
+#define HANDLE_DIM(D) \
+ if (reduction_dims.size() == (D)) { \
+ HandleReduce<T, NDIM, (D)>(context, reduction_dims, result); \
+ return; \
+ }
+ // NOTE(keveman): Handling the most common case here.
+ // Adding more cases here would require more templating and code
+ // explosion. For instance, HANDLE_DIM(2) wouldn't make sense for NDIM=1.
+ HANDLE_DIM(NDIM > 0 ? 1 : 0);
+
+// Fall through to the unoptimized version.
+#undef HANDLE_DIM
+ }
+
+ Eigen::DSizes<ptrdiff_t, NDIM> indices;
+ Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+
+ // Accumulate slices along the dimensions into the output. The number of
+ // slices along dimension 'i' is simply the multiple along dimension 'i'
+ // passed to the original Tile op.
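+ // For example, if the original Tile input had shape [2] and multiples [3],
+ // the incoming gradient here has shape [6]; then sizes = [2] and the loop
+ // below adds the slices at offsets 0, 2 and 4 into the [2]-shaped result.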
+ for (int i = 0; i < NDIM; ++i) {
+ sizes[i] = input_dims[i] / multiples_array[i];
+ indices[i] = 0;
+ }
+
+ bool first = true;
+ while (true) {
+ functor::TileGrad<Device, T, NDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), indices, sizes, first);
+ first = false;
+ // Increment the begin indices.
+ int i = 0;
+ while (i < NDIM && indices[i] / sizes[i] == multiples_array[i] - 1) {
+ indices[i] = 0;
+ ++i;
+ }
+ // We are finished if we have iterated to the maximum along all
+ // dimensions.
+ if (i == NDIM) {
+ break;
+ }
+ indices[i] += sizes[i];
+ }
+ }
+
+ template <typename T, int NDIM, int REDUCENDIM>
+ void HandleReduce(OpKernelContext* context,
+ const std::vector<int32>& reduce_dim_in, Tensor* result) {
+ static_assert(NDIM >= REDUCENDIM, "Too many reduced dimensions");
+ Eigen::DSizes<ptrdiff_t, REDUCENDIM> reduce_dim;
+ Eigen::DSizes<ptrdiff_t, NDIM> reshape_dim;
+
+ for (int i = 0; i < REDUCENDIM; ++i) {
+ reduce_dim[i] = reduce_dim_in[i];
+ }
+
+ for (int i = 0; i < NDIM; ++i) {
+ reshape_dim[i] = result->dim_size(i);
+ }
+
+ functor::ReduceAndReshape<Device, T, NDIM, REDUCENDIM>()(
+ context->eigen_device<Device>(), result->tensor<T, NDIM>(),
+ context->input(0).tensor<T, NDIM>(), reduce_dim, reshape_dim);
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TileGradientOp);
+};
+
+template <typename Device>
+template <DataType DT, int NDIM>
+inline void TileGradientOp<Device>::HandleCase(
+ OpKernelContext* context, const std::vector<int32>& input_dims,
+ const gtl::ArraySlice<int32>& multiples_array, Tensor* result) {
+ LOG(FATAL) << "TileGradientOp: Invalid combination of Device, DT and NDIM: "
+ << typeid(Device).name() << ", " << DataTypeString(DT) << ", "
+ << NDIM;
+}
+
+#define HANDLE_CASE(device, dtype, ndim) \
+ template <> \
+ template <> \
+ void TileGradientOp<device>::HandleCase<dtype, ndim>( \
+ OpKernelContext * context, const std::vector<int32>& input_dims, \
+ const gtl::ArraySlice<int32>& multiples_array, Tensor* result) { \
+ HandleCaseImpl<dtype, ndim>(context, input_dims, multiples_array, result); \
+ }
+
+#define HANDLE_CASE_DIM_POSITIVE(device, dtype) \
+ HANDLE_CASE(device, dtype, 1); \
+ HANDLE_CASE(device, dtype, 2); \
+ HANDLE_CASE(device, dtype, 3); \
+ HANDLE_CASE(device, dtype, 4); \
+ HANDLE_CASE(device, dtype, 5);
+
+#define HANDLE_CASE_DIM(device, dtype) \
+ HANDLE_CASE(device, dtype, 0); \
+ HANDLE_CASE_DIM_POSITIVE(device, dtype);
+
+HANDLE_CASE_DIM(CPUDevice, DT_FLOAT);
+HANDLE_CASE_DIM(CPUDevice, DT_DOUBLE);
+HANDLE_CASE_DIM(CPUDevice, DT_INT16);
+HANDLE_CASE_DIM(CPUDevice, DT_INT32);
+HANDLE_CASE_DIM(CPUDevice, DT_INT64);
+
+#if GOOGLE_CUDA
+// Eigen on GPU does not handle 0-dimension data types yet.
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_FLOAT);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_DOUBLE);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT16);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT32);
+HANDLE_CASE_DIM_POSITIVE(GPUDevice, DT_INT64);
+#endif // GOOGLE_CUDA
+
+#undef HANDLE_CASE_DIM_POSITIVE
+#undef HANDLE_CASE_DIM
+#undef HANDLE_CASE
+
+REGISTER_KERNEL_BUILDER(Name("Tile").Device(DEVICE_CPU).HostMemory("multiples"),
+ TileOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TileGrad")
+ .Device(DEVICE_CPU)
+ .HostMemory("multiples"),
+ TileGradientOp<CPUDevice>);
+
+#if GOOGLE_CUDA
+#define DEFINE_GPU_TYPE(T) \
+ DEFINE_GPU_DIM(T, 1) \
+ DEFINE_GPU_DIM(T, 2) \
+ DEFINE_GPU_DIM(T, 3) \
+ DEFINE_GPU_DIM(T, 4) \
+ DEFINE_GPU_DIM(T, 5)
+
+#define DEFINE_GPU_DIM(T, NDIM) \
+ template <> \
+ void Tile<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
+ typename TTypes<T, NDIM>::ConstTensor in, \
+ const Eigen::array<int32, NDIM>& broadcast_array) const; \
+ extern template struct Tile<GPUDevice, T, NDIM>; \
+ template <> \
+ void TileGrad<GPUDevice, T, NDIM>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
+ typename TTypes<T, NDIM>::ConstTensor in, \
+ const Eigen::DSizes<ptrdiff_t, NDIM>& indices, \
+ const Eigen::DSizes<ptrdiff_t, NDIM>& sizes, bool first) const; \
+ extern template struct TileGrad<GPUDevice, T, NDIM>; \
+ template <> \
+ void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
+ const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
+ typename TTypes<T, NDIM>::ConstTensor in, \
+ const Eigen::DSizes<ptrdiff_t, 1>& reduce_dim, \
+ const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const; \
+ extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
+
+namespace functor {
+DEFINE_GPU_TYPE(float);
+DEFINE_GPU_TYPE(double);
+DEFINE_GPU_TYPE(int64);
+DEFINE_GPU_TYPE(int32);
+DEFINE_GPU_TYPE(int16);
+} // end namespace functor
+
+#undef DEFINE_GPU_DIM
+#undef DEFINE_GPU_TYPE
+
+REGISTER_KERNEL_BUILDER(Name("Tile")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("multiples"),
+ TileOp<GPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("Tile")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("multiples"),
+ TileOp<GPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("Tile")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int16>("T")
+ .HostMemory("multiples"),
+ TileOp<GPUDevice>);
+
+REGISTER_KERNEL_BUILDER(Name("TileGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("multiples"),
+ TileGradientOp<GPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TileGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("multiples"),
+ TileGradientOp<GPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TileGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int16>("T")
+ .HostMemory("multiples"),
+ TileGradientOp<GPUDevice>);
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h
new file mode 100644
index 0000000000..b3cc6165e0
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops.h
@@ -0,0 +1,48 @@
+#ifndef TENSORFLOW_KERNELS_TILE_OPS_H_
+#define TENSORFLOW_KERNELS_TILE_OPS_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
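+// Tile replicates `in` along each dimension according to broadcast_array.
+// For example, with NDIM = 2 a broadcast_array of {2, 3} turns a [4, 5]
+// input into an [8, 15] output.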
+template <typename Device, typename T, int NDIM>
+struct Tile {
+ void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::array<int32, NDIM>& broadcast_array) const {
+ out.device(d) = in.broadcast(broadcast_array);
+ }
+};
+
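+// TileGrad copies (first == true) or accumulates (first == false) the slice
+// of `in` at `indices` with extent `sizes` into `out`. TileGradientOp calls
+// it once per tile of the forward op, so `out` needs no separate
+// zero-initialization.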
+template <typename Device, typename T, int NDIM>
+struct TileGrad {
+ void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::DSizes<ptrdiff_t, NDIM>& indices,
+ const Eigen::DSizes<ptrdiff_t, NDIM>& sizes,
+ bool first) const {
+ if (first) {
+ out.device(d) = in.slice(indices, sizes);
+ } else {
+ out.device(d) += in.slice(indices, sizes);
+ }
+ }
+};
+
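+// ReduceAndReshape sums `in` over `reduce_dim` and reshapes the result to
+// `reshape_dim`. It backs the reduction-only fast path in TileGradientOp.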
+template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
+struct ReduceAndReshape {
+ void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
+ typename TTypes<T, NDIM>::ConstTensor in,
+ const Eigen::DSizes<ptrdiff_t, REDUCEDNDIM>& reduce_dim,
+ const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const {
+ out.device(d) = in.sum(reduce_dim).reshape(reshape_dim);
+ }
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_TILE_OPS_H_
diff --git a/tensorflow/core/kernels/tile_ops_gpu.cu.cc b/tensorflow/core/kernels/tile_ops_gpu.cu.cc
new file mode 100644
index 0000000000..29481e1a54
--- /dev/null
+++ b/tensorflow/core/kernels/tile_ops_gpu.cu.cc
@@ -0,0 +1,38 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/tile_ops.h"
+#include <stdio.h>
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_TYPE(T) \
+ DEFINE_DIM(T, 1) \
+ DEFINE_DIM(T, 2) \
+ DEFINE_DIM(T, 3) \
+ DEFINE_DIM(T, 4) \
+ DEFINE_DIM(T, 5)
+
+#define DEFINE_DIM(T, NDIM) \
+ template struct Tile<GPUDevice, T, NDIM>; \
+ template struct TileGrad<GPUDevice, T, NDIM>; \
+ template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
+
+DEFINE_TYPE(float)
+DEFINE_TYPE(double)
+DEFINE_TYPE(int64)
+DEFINE_TYPE(int32)
+DEFINE_TYPE(int16)
+// NOTE(keveman): Eigen's int8 and string versions don't compile yet with nvcc.
+
+#undef DEFINE_DIM
+#undef DEFINE_TYPE
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
new file mode 100644
index 0000000000..79b5d4d07e
--- /dev/null
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -0,0 +1,71 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+template <typename T>
+class TopK : public OpKernel {
+ public:
+ explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const auto& input_in = context->input(0);
+ OP_REQUIRES(context, input_in.dims() == 2,
+ errors::InvalidArgument("input must be 2-dimensional"));
+ OP_REQUIRES(context, input_in.dim_size(1) >= k_,
+ errors::InvalidArgument("input must have at least k columns"));
+
+ const auto& input = input_in.matrix<T>();
+
+ const auto num_rows = input_in.dim_size(0); // generally batch_size
+ const auto num_cols = input_in.dim_size(1);
+
+ Tensor* values_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({num_rows, k_}), &values_out));
+ Tensor* indices_out = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, TensorShape({num_rows, k_}), &indices_out));
+ auto values = values_out->matrix<T>();
+ auto indices = indices_out->matrix<int32>();
+
+ gtl::TopN<std::pair<T, int32>> filter(k_);
+
+ for (int r = 0; r < num_rows; r++) {
+ for (int32 c = 0; c < num_cols; ++c) {
+ // The second element is the negated index, so that lower-index elements
+ // are considered larger than higher-index elements in case of ties.
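+ // For example, with k = 2 and row values {3, 1, 3}, the two largest pairs
+ // are (3, 0) and (3, -2), so values = {3, 3} and indices = {0, 2}: the tie
+ // is resolved in favor of the lower index.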
+ filter.push(std::make_pair(input(r, c), -c));
+ }
+
+ std::unique_ptr<std::vector<std::pair<T, int32>>> top_k(filter.Extract());
+ for (int32 i = 0; i < k_; ++i) {
+ values(r, i) = (*top_k)[i].first;
+ indices(r, i) = -(*top_k)[i].second;
+ }
+ filter.Reset();
+ }
+ }
+
+ private:
+ int k_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TopK").Device(DEVICE_CPU).TypeConstraint<type>("T"), TopK<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
new file mode 100644
index 0000000000..611fa4ac41
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -0,0 +1,884 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/training_ops.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
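+// For variables of at most 256K elements, evaluate the update expressions
+// inline on the calling thread rather than through the Eigen device, which
+// avoids thread-pool dispatch overhead for small tensors.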
+static inline bool DoInline(int64 size) { return size <= (256ll << 10); }
+
+template <typename T>
+struct ApplyGradientDescent<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad) {
+ if (DoInline(var.size())) {
+ var -= grad * lr();
+ } else {
+ var.device(d) -= grad * lr();
+ }
+ }
+};
+
+template <typename T>
+struct ApplyAdagrad<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad) {
+ if (DoInline(var.size())) {
+ accum += grad.square();
+ var -= grad * lr() * accum.rsqrt();
+ } else {
+ accum.device(d) += grad.square();
+ var.device(d) -= grad * lr() * accum.rsqrt();
+ }
+ }
+};
+
+template <typename T>
+struct ApplyMomentum<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad,
+ typename TTypes<T>::ConstScalar momentum) {
+ if (DoInline(var.size())) {
+ accum = accum * momentum() + grad;
+ var -= accum * lr();
+ } else {
+ accum.device(d) = accum * momentum() + grad;
+ var.device(d) -= accum * lr();
+ }
+ }
+};
+
+template <typename T>
+struct ApplyAdam<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar beta2_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ const T alpha = lr() * std::sqrt(1 - beta2_power()) / (1 - beta1_power());
+ if (DoInline(var.size())) {
+ m += (grad - m) * (1 - beta1());
+ v += (grad.square() - v) * (1 - beta2());
+ var -= (m * alpha) / (v.sqrt() + epsilon());
+ } else {
+ m.device(d) += (grad - m) * (1 - beta1());
+ v.device(d) += (grad.square() - v) * (1 - beta2());
+ var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
+ }
+ }
+};
+
+template <typename T>
+struct ApplyRMSProp<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar rho,
+ typename TTypes<T>::ConstScalar momentum,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ if (DoInline(var.size())) {
+ ms += (grad.square() - ms) * (1 - rho());
+ mom = mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt());
+ var -= mom;
+ } else {
+ ms.device(d) += (grad.square() - ms) * (1 - rho());
+ mom.device(d) =
+ mom * momentum() + (grad * lr()) / ((ms + epsilon()).sqrt());
+ var.device(d) -= mom;
+ }
+ }
+};
+
+} // namespace functor
+
+template <typename Device, typename T>
+class ApplyGradientDescentOp : public OpKernel {
+ public:
+ explicit ApplyGradientDescentOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (use_exclusive_lock_) {
+ mutex_lock l(*ctx->input_ref_mutex(0));
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ } else {
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ }
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+
+ void DoValidate(OpKernelContext* ctx) {
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ const Tensor& alpha = ctx->input(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha.shape().DebugString()));
+ const Tensor& delta = ctx->input(2);
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(delta.shape()),
+ errors::InvalidArgument("var and delta do not have the same shape",
+ var.shape().DebugString(), " ",
+ delta.shape().DebugString()));
+ }
+
+ void DoCompute(OpKernelContext* ctx) {
+ const Device& device = ctx->template eigen_device<Device>();
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ const Tensor& alpha = ctx->input(1);
+ const Tensor& delta = ctx->input(2);
+ functor::ApplyGradientDescent<Device, T>()(
+ device, var.flat<T>(), alpha.scalar<T>(), delta.flat<T>());
+ }
+};
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyGradientDescent").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyGradientDescentOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyGradientDescent<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::ConstScalar alpha, \
+ typename TTypes<T>::ConstFlat delta); \
+ extern template struct ApplyGradientDescent<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
+class ApplyAdagradOp : public OpKernel {
+ public:
+ explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (use_exclusive_lock_) {
+ mutex_lock l1(*ctx->input_ref_mutex(0));
+ // Don't try to acquire a lock on the second ref as they share the same
+ // mutex.
+ //
+ // mutex_lock l2(*ctx->input_ref_mutex(1));
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ } else {
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ }
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+
+ void DoValidate(OpKernelContext* ctx) {
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ const Tensor& lr = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ const Tensor& grad = ctx->input(3);
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ accum.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and delta do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+ }
+
+ void DoCompute(OpKernelContext* ctx) {
+ const Device& device = ctx->template eigen_device<Device>();
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ const Tensor& lr = ctx->input(2);
+ const Tensor& grad = ctx->input(3);
+ functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
+ lr.scalar<T>(), grad.flat<T>());
+ }
+};
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdagrad").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdagradOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyAdagrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
+ typename TTypes<T>::ConstFlat grad); \
+ extern template struct ApplyAdagrad<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_KERNELS
+
+// Note, this op works on cpu only.
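+// For each i, only row indices(i) of var and accum is updated, using row i
+// of grad:
+//   accum[row] += grad[i]^2;  var[row] -= lr * grad[i] / sqrt(accum[row]).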
+template <typename T, typename Tindex>
+class SparseApplyAdagradOp : public OpKernel {
+ public:
+ explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+ mutex* mu_var = ctx->input_ref_mutex(0);
+ // mu_accum is actually the same mutex as mu_var since currently we use a
+ // global mutex.
+ //
+ // mutex* mu_accum = ctx->input_ref_mutex(1);
+ if (use_exclusive_lock_) {
+ mu_var->lock();
+ }
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ accum.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
+ errors::InvalidArgument("var must be at least 1 dimensional"));
+
+ const Tensor& lr = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ const Tensor& grad = ctx->input(3);
+ const Tensor& indices = ctx->input(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+ errors::InvalidArgument("indices must be one-dimensional"));
+
+ for (int d = 1; d < var.dims(); d++) {
+ OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
+ errors::InvalidArgument(strings::StrCat(
+ "var and grad must match in dimension ", d)));
+ }
+ const Tindex N = indices.dim_size(0);
+ OP_REQUIRES(
+ ctx, grad.dim_size(0) == N,
+ errors::InvalidArgument(
+ "grad must be the same size as indices in the first dimension."));
+
+ if (N > 0) {
+ const Tindex first_dim_size = var.dim_size(0);
+ // Validate all the indices are in range
+ auto indices_vec = indices.vec<Tindex>();
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = indices_vec(i);
+ OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ }
+
+ auto var_flat = var.flat_outer_dims<T>();
+ auto accum_flat = accum.flat_outer_dims<T>();
+ auto grad_flat = grad.flat_outer_dims<T>();
+ T lr_scalar = lr.scalar<T>()();
+
+ // Note(yonghui): It might be worth multi-threading square() and rsqrt().
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = indices_vec(i);
+ auto a = accum_flat.template chip<0>(index);
+ auto g = grad_flat.template chip<0>(i);
+ auto v = var_flat.template chip<0>(index);
+ a += g.square();
+ v -= g.constant(lr_scalar) * g * a.rsqrt();
+ }
+ }
+ if (use_exclusive_lock_) {
+ mu_var->unlock();
+ }
+
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyAdagrad") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyAdagradOp<T, Tindices>);
+
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
+class ApplyMomentumOp : public OpKernel {
+ public:
+ explicit ApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (use_exclusive_lock_) {
+ mutex_lock l1(*ctx->input_ref_mutex(0));
+ // Don't try to acquire a lock on the second ref as they share the same
+ // mutex.
+ //
+ // mutex_lock l2(*ctx->input_ref_mutex(1));
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ } else {
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ }
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+
+ void DoValidate(OpKernelContext* ctx) {
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ const Tensor& lr = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ const Tensor& grad = ctx->input(3);
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ accum.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and delta do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+
+ const Tensor& momentum = ctx->input(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+ errors::InvalidArgument("momentum is not a scalar: ",
+ momentum.shape().DebugString()));
+ }
+
+ void DoCompute(OpKernelContext* ctx) {
+ const Device& device = ctx->template eigen_device<Device>();
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ const Tensor& lr = ctx->input(2);
+ const Tensor& grad = ctx->input(3);
+ const Tensor& momentum = ctx->input(4);
+ functor::ApplyMomentum<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
+ lr.scalar<T>(), grad.flat<T>(),
+ momentum.scalar<T>());
+ }
+};
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyMomentum").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyMomentumOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyMomentum<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
+ typename TTypes<T>::ConstFlat grad, \
+ typename TTypes<T>::ConstScalar momentum); \
+ extern template struct ApplyMomentum<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_KERNELS
+
+// Note, this op works on cpu only.
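+// For each i, only row indices(i) of var and accum is updated, using row i
+// of grad:
+//   accum[row] = momentum * accum[row] + grad[i];
+//   var[row] -= lr * accum[row].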
+template <typename T, typename Tindex>
+class SparseApplyMomentumOp : public OpKernel {
+ public:
+ explicit SparseApplyMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+ mutex* mu_var = ctx->input_ref_mutex(0);
+ // mu_accum is actually the same mutex as mu_var since currently we use a
+ // global mutex.
+ //
+ // mutex* mu_accum = ctx->input_ref_mutex(1);
+ if (use_exclusive_lock_) {
+ mu_var->lock();
+ }
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor accum = ctx->mutable_input(1, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, accum.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(accum.shape()),
+ errors::InvalidArgument("var and accum do not have the same shape",
+ var.shape().DebugString(), " ",
+ accum.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()),
+ errors::InvalidArgument("var must be at least 1 dimensional"));
+
+ const Tensor& lr = ctx->input(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ const Tensor& grad = ctx->input(3);
+ const Tensor& indices = ctx->input(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+ errors::InvalidArgument("indices must be one-dimensional"));
+
+ for (int d = 1; d < var.dims(); d++) {
+ OP_REQUIRES(ctx, var.dim_size(d) == grad.dim_size(d),
+ errors::InvalidArgument(strings::StrCat(
+ "var and grad must match in dimension ", d)));
+ }
+ const Tindex N = indices.dim_size(0);
+ OP_REQUIRES(
+ ctx, grad.dim_size(0) == N,
+ errors::InvalidArgument(
+ "grad must be the same size as indices in the first dimension."));
+
+ const Tensor& momentum = ctx->input(5);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+ errors::InvalidArgument("momentum is not a scalar: ",
+ momentum.shape().DebugString()));
+
+ if (N > 0) {
+ const Tindex first_dim_size = var.dim_size(0);
+ // Validate all the indices are in range
+ auto indices_vec = indices.vec<Tindex>();
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = indices_vec(i);
+ OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ }
+
+ auto var_flat = var.flat_outer_dims<T>();
+ auto accum_flat = accum.flat_outer_dims<T>();
+ auto grad_flat = grad.flat_outer_dims<T>();
+ T lr_scalar = lr.scalar<T>()();
+ T momentum_scalar = momentum.scalar<T>()();
+
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = indices_vec(i);
+ auto a = accum_flat.template chip<0>(index);
+ auto g = grad_flat.template chip<0>(i);
+ auto v = var_flat.template chip<0>(index);
+ a = a * a.constant(momentum_scalar) + g;
+ v -= a.constant(lr_scalar) * a;
+ }
+ }
+ if (use_exclusive_lock_) {
+ mu_var->unlock();
+ }
+
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyMomentum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyMomentumOp<T, Tindices>);
+
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
+class ApplyAdamOp : public OpKernel {
+ public:
+ explicit ApplyAdamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (use_exclusive_lock_) {
+ // all input refs share the same mutex
+ mutex_lock l1(*ctx->input_ref_mutex(0));
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ } else {
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ }
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+
+ void DoValidate(OpKernelContext* ctx) {
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor m = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor v = ctx->mutable_input(2, use_exclusive_lock_);
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, m.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, v.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(2)));
+
+ const Tensor& beta1_power = ctx->input(3);
+ const Tensor& beta2_power = ctx->input(4);
+ const Tensor& lr = ctx->input(5);
+ const Tensor& beta1 = ctx->input(6);
+ const Tensor& beta2 = ctx->input(7);
+ const Tensor& epsilon = ctx->input(8);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_power.shape()),
+ errors::InvalidArgument("beta2_power is not a scalar: ",
+ beta2_power.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon.shape().DebugString()));
+
+ const Tensor& grad = ctx->input(9);
+ OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var.shape().DebugString(), " ",
+ m.shape().DebugString()));
+ OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var.shape().DebugString(), " ",
+ v.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and grad do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+ }
+
+ void DoCompute(OpKernelContext* ctx) {
+ const Device& device = ctx->template eigen_device<Device>();
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor m = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor v = ctx->mutable_input(2, use_exclusive_lock_);
+ const Tensor& beta1_power = ctx->input(3);
+ const Tensor& beta2_power = ctx->input(4);
+ const Tensor& lr = ctx->input(5);
+ const Tensor& beta1 = ctx->input(6);
+ const Tensor& beta2 = ctx->input(7);
+ const Tensor& epsilon = ctx->input(8);
+ const Tensor& grad = ctx->input(9);
+
+ functor::ApplyAdam<Device, T>()(device, var.flat<T>(), m.flat<T>(),
+ v.flat<T>(), beta1_power.scalar<T>(),
+ beta2_power.scalar<T>(), lr.scalar<T>(),
+ beta1.scalar<T>(), beta2.scalar<T>(),
+ epsilon.scalar<T>(), grad.flat<T>());
+ }
+};
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdam").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdamOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyAdam<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
+ typename TTypes<T>::ConstScalar beta1_power, \
+ typename TTypes<T>::ConstScalar beta2_power, \
+ typename TTypes<T>::ConstScalar lr, \
+ typename TTypes<T>::ConstScalar beta1, \
+ typename TTypes<T>::ConstScalar beta2, \
+ typename TTypes<T>::ConstScalar epsilon, \
+ typename TTypes<T>::ConstFlat grad); \
+ extern template struct ApplyAdam<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
+class ApplyRMSPropOp : public OpKernel {
+ public:
+ explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ if (use_exclusive_lock_) {
+ // all input refs share the same mutex
+ mutex_lock l1(*ctx->input_ref_mutex(0));
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ } else {
+ DoValidate(ctx);
+ if (!ctx->status().ok()) return;
+ DoCompute(ctx);
+ }
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+
+ void DoValidate(OpKernelContext* ctx) {
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, ms.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, mom.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(2)));
+
+ const Tensor& lr = ctx->input(3);
+ const Tensor& rho = ctx->input(4);
+ const Tensor& momentum = ctx->input(5);
+ const Tensor& epsilon = ctx->input(6);
+ const Tensor& grad = ctx->input(7);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+ errors::InvalidArgument("rho is not a scalar: ",
+ rho.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+ errors::InvalidArgument("momentum is not a scalar: ",
+ momentum.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon.shape().DebugString()));
+
+ OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
+ errors::InvalidArgument("var and ms do not have the same shape",
+ var.shape().DebugString(), " ",
+ ms.shape().DebugString()));
+
+ OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
+ errors::InvalidArgument(
+ "var and mom do not have the same shape",
+ var.shape().DebugString(), " ", mom.shape().DebugString()));
+
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and grad do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+ }
+
+ void DoCompute(OpKernelContext* ctx) {
+ const Device& device = ctx->template eigen_device<Device>();
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+ const Tensor& lr = ctx->input(3);
+ const Tensor& rho = ctx->input(4);
+ const Tensor& momentum = ctx->input(5);
+ const Tensor& epsilon = ctx->input(6);
+ const Tensor& grad = ctx->input(7);
+
+ functor::ApplyRMSProp<Device, T>()(device, var.flat<T>(), ms.flat<T>(),
+ mom.flat<T>(), lr.scalar<T>(),
+ rho.scalar<T>(), momentum.scalar<T>(),
+ epsilon.scalar<T>(), grad.flat<T>());
+ }
+};
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyRMSProp").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyRMSPropOp<D##Device, T>);
+
+REGISTER_KERNELS(CPU, float);
+REGISTER_KERNELS(CPU, double);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyRMSProp<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, \
+ typename TTypes<T>::ConstScalar lr, typename TTypes<T>::ConstScalar rho, \
+ typename TTypes<T>::ConstScalar momentum, \
+ typename TTypes<T>::ConstScalar epsilon, \
+ typename TTypes<T>::ConstFlat grad); \
+ extern template struct ApplyRMSProp<GPUDevice, T>;
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
new file mode 100644
index 0000000000..71f6d0253d
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops.h
@@ -0,0 +1,65 @@
+#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
+#define TENSORFLOW_KERNELS_TRAINING_OPS_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Each training algorithm has an ApplyXYZ functor struct declared in
+// this header file. They are specialized for different devices
+// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc).
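+//
+// A kernel invokes a functor with the tensors already flattened, e.g.:
+//   functor::ApplyGradientDescent<Device, T>()(
+//       ctx->eigen_device<Device>(), var.flat<T>(), alpha.scalar<T>(),
+//       delta.flat<T>());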
+
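+// ApplyGradientDescent: var -= alpha * delta.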
+template <typename Device, typename T>
+struct ApplyGradientDescent {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::ConstScalar alpha,
+ typename TTypes<T>::ConstFlat delta);
+};
+
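+// ApplyAdagrad: accum += grad^2; var -= lr * grad / sqrt(accum).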
+template <typename Device, typename T>
+struct ApplyAdagrad {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad);
+};
+
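+// ApplyMomentum: accum = momentum * accum + grad; var -= lr * accum.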
+template <typename Device, typename T>
+struct ApplyMomentum {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad,
+ typename TTypes<T>::ConstScalar momentum);
+};
+
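+// ApplyAdam: m += (1 - beta1) * (grad - m);
+//            v += (1 - beta2) * (grad^2 - v);
+//            var -= lr * sqrt(1 - beta2_power) / (1 - beta1_power) *
+//                   m / (sqrt(v) + epsilon).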
+template <typename Device, typename T>
+struct ApplyAdam {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar beta2_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad);
+};
+
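+// ApplyRMSProp: ms += (1 - rho) * (grad^2 - ms);
+//               mom = momentum * mom + lr * grad / sqrt(ms + epsilon);
+//               var -= mom.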
+template <typename Device, typename T>
+struct ApplyRMSProp {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar rho,
+ typename TTypes<T>::ConstScalar momentum,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad);
+};
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_TRAINING_OPS_H_
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
new file mode 100644
index 0000000000..3106f29648
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -0,0 +1,127 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/training_ops.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
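+// The scalar parameters (lr, rho, momentum, ...) are 0-d tensors resident in
+// GPU memory, so they cannot be read on the host via lr(); instead each
+// scalar is reshaped to a size-1 vector and broadcast to the length of the
+// flat tensors inside the Eigen expression.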
+template <typename T>
+struct ApplyGradientDescent<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::ConstScalar alpha,
+ typename TTypes<T>::ConstFlat delta) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = delta.dimension(0);
+ Eigen::Sizes<1> single;
+ var.device(d) -= alpha.reshape(single).broadcast(bcast) * delta;
+ }
+};
+
+template <typename T>
+struct ApplyAdagrad<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad) {
+ accum.device(d) += grad.square();
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
+ }
+};
+
+template <typename T>
+struct ApplyMomentum<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat accum,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstFlat grad,
+ typename TTypes<T>::ConstScalar momentum) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
+ var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
+ }
+};
+
+template <typename T>
+struct ApplyAdam<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar beta2_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ const auto one = static_cast<T>(1.0);
+ m.device(d) =
+ m +
+ (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+ (grad - m);
+ v.device(d) =
+ v +
+ (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
+ (grad.square() - v);
+ var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
+ (beta1_power.constant(one) - beta1_power))
+ .reshape(single)
+ .broadcast(bcast) *
+ m / (epsilon.reshape(single).broadcast(bcast) + v.sqrt());
+ }
+};
+
+template <typename T>
+struct ApplyRMSProp<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar rho,
+ typename TTypes<T>::ConstScalar momentum,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ const auto one = static_cast<T>(1.0);
+ ms.device(d) = ms +
+ (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
+ (grad.square() - ms);
+ mom.device(d) =
+ mom * momentum.reshape(single).broadcast(bcast) +
+ lr.reshape(single).broadcast(bcast) * grad /
+ ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
+ var.device(d) -= mom;
+ }
+};
+
+} // namespace functor
+
+template struct functor::ApplyGradientDescent<GPUDevice, float>;
+template struct functor::ApplyGradientDescent<GPUDevice, double>;
+
+template struct functor::ApplyAdagrad<GPUDevice, float>;
+template struct functor::ApplyAdagrad<GPUDevice, double>;
+
+template struct functor::ApplyMomentum<GPUDevice, float>;
+template struct functor::ApplyMomentum<GPUDevice, double>;
+
+template struct functor::ApplyAdam<GPUDevice, float>;
+template struct functor::ApplyAdam<GPUDevice, double>;
+
+template struct functor::ApplyRMSProp<GPUDevice, float>;
+template struct functor::ApplyRMSProp<GPUDevice, double>;
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/training_ops_test.cc b/tensorflow/core/kernels/training_ops_test.cc
new file mode 100644
index 0000000000..3c629badb6
--- /dev/null
+++ b/tensorflow/core/kernels/training_ops_test.cc
@@ -0,0 +1,226 @@
+#include <gtest/gtest.h>
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+// We focus on the single thread performance of training ops.
+static SessionOptions InitSingleThreadedOptions() {
+ SessionOptions opts;
+ opts.config.set_intra_op_parallelism_threads(1);
+ opts.config.set_inter_op_parallelism_threads(1);
+ return opts;
+}
+
+static SessionOptions* GetOptions() {
+ static SessionOptions opts = InitSingleThreadedOptions();
+ return &opts;
+}
+
+static Node* Var(Graph* g, int n) {
+ return test::graph::Var(g, DT_FLOAT, TensorShape({n}));
+}
+
+static Node* Zeros(Graph* g, int n) {
+ Tensor data(DT_FLOAT, TensorShape({n}));
+ data.flat<float>().setZero();
+ return test::graph::Constant(g, data);
+}
+
+static Node* Random(Graph* g, int n) {
+ Tensor data(DT_FLOAT, TensorShape({n}));
+ data.flat<float>().setRandom();
+ return test::graph::Constant(g, data);
+}
+
+static Node* Scalar(Graph* g, float val) {
+ Tensor data(DT_FLOAT, TensorShape({}));
+ data.flat<float>()(0) = val;
+ return test::graph::Constant(g, data);
+}
+
+static void SGD(int32 n, Graph** init_g, Graph** train_g) {
+ RequireDefaultOps();
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ test::graph::Assign(g, var, Zeros(g, n));
+ *init_g = g;
+ }
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto lr = Scalar(g, 0.01);
+ auto grad = Random(g, n);
+ test::graph::Multi(g, "ApplyGradientDescent", {var, lr, grad});
+ *train_g = g;
+ }
+}
+
+static void BM_SGD(int iters, int params) {
+ const int64 tot = static_cast<int64>(iters) * params;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * sizeof(float));
+ Graph* init;
+ Graph* train;
+ SGD(params, &init, &train);
+ test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+}
+BENCHMARK(BM_SGD)->Arg(128 << 10)->Arg(256 << 10);
+
+static void Adagrad(int32 n, Graph** init_g, Graph** train_g) {
+ RequireDefaultOps();
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto accum = Var(g, n);
+ auto zero = Zeros(g, n);
+ test::graph::Assign(g, var, zero);
+ test::graph::Assign(g, accum, zero);
+ *init_g = g;
+ }
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto accum = Var(g, n);
+ auto lr = Scalar(g, 0.01);
+ auto grad = Random(g, n);
+ test::graph::Multi(g, "ApplyAdagrad", {var, accum, lr, grad});
+ *train_g = g;
+ }
+}
+
+static void BM_Adagrad(int iters, int params) {
+ const int64 tot = static_cast<int64>(iters) * params;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * sizeof(float));
+ Graph* init;
+ Graph* train;
+ Adagrad(params, &init, &train);
+ test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+}
+BENCHMARK(BM_Adagrad)->Arg(128 << 10)->Arg(256 << 10);
+
+static void Momentum(int32 n, Graph** init_g, Graph** train_g) {
+ RequireDefaultOps();
+ TensorShape shape({n});
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto accum = Var(g, n);
+ auto zero = Zeros(g, n);
+ test::graph::Assign(g, var, zero);
+ test::graph::Assign(g, accum, zero);
+ *init_g = g;
+ }
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto accum = Var(g, n);
+ auto lr = Scalar(g, 0.01);
+ auto grad = Random(g, n);
+ auto mom = Scalar(g, 0.01);
+ test::graph::Multi(g, "ApplyMomentum", {var, accum, lr, grad, mom});
+ *train_g = g;
+ }
+}
+
+static void BM_Momentum(int iters, int params) {
+ const int64 tot = static_cast<int64>(iters) * params;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * sizeof(float));
+ Graph* init;
+ Graph* train;
+ Momentum(params, &init, &train);
+ test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+}
+BENCHMARK(BM_Momentum)->Arg(128 << 10)->Arg(256 << 10);
+
+static void Adam(int32 n, Graph** init_g, Graph** train_g) {
+ RequireDefaultOps();
+ TensorShape shape({n});
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto m = Var(g, n);
+ auto v = Var(g, n);
+ auto zero = Zeros(g, n);
+ test::graph::Assign(g, var, zero);
+ test::graph::Assign(g, m, zero);
+ test::graph::Assign(g, v, zero);
+ *init_g = g;
+ }
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto m = Var(g, n);
+ auto v = Var(g, n);
+ auto beta1_power = Scalar(g, 0.9);
+ auto beta2_power = Scalar(g, 0.99);
+ auto lr = Scalar(g, 0.01);
+ auto beta1 = Scalar(g, 0.9);
+ auto beta2 = Scalar(g, 0.99);
+ auto epsilon = Scalar(g, 1e-8);
+ auto grad = Random(g, n);
+ test::graph::Multi(g, "ApplyAdam", {var, m, v, beta1_power, beta2_power, lr,
+ beta1, beta2, epsilon, grad});
+ *train_g = g;
+ }
+}
+
+static void BM_Adam(int iters, int params) {
+ const int64 tot = static_cast<int64>(iters) * params;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * sizeof(float));
+ Graph* init;
+ Graph* train;
+ Adam(params, &init, &train);
+ test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+}
+BENCHMARK(BM_Adam)->Arg(128 << 10)->Arg(256 << 10);
+
+static void RMSProp(int32 n, Graph** init_g, Graph** train_g) {
+ RequireDefaultOps();
+ TensorShape shape({n});
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto ms = Var(g, n);
+ auto mom = Var(g, n);
+ auto zero = Zeros(g, n);
+ test::graph::Assign(g, var, zero);
+ test::graph::Assign(g, ms, zero);
+ test::graph::Assign(g, mom, zero);
+ *init_g = g;
+ }
+ {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto var = Var(g, n);
+ auto ms = Var(g, n);
+ auto mom = Var(g, n);
+ auto lr = Scalar(g, 0.01);
+ auto rho = Scalar(g, 0.9);
+ auto momentum = Scalar(g, 0.9);
+ auto epsilon = Scalar(g, 1e-8);
+ auto grad = Random(g, n);
+ test::graph::Multi(g, "ApplyRMSProp",
+ {var, ms, mom, lr, rho, momentum, epsilon, grad});
+ *train_g = g;
+ }
+}
+
+static void BM_RMSProp(int iters, int params) {
+ const int64 tot = static_cast<int64>(iters) * params;
+ testing::ItemsProcessed(tot);
+ testing::BytesProcessed(tot * sizeof(float));
+ Graph* init;
+ Graph* train;
+ RMSProp(params, &init, &train);
+ test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
+}
+BENCHMARK(BM_RMSProp)->Arg(128 << 10)->Arg(256 << 10);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
new file mode 100644
index 0000000000..4f11a881f8
--- /dev/null
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -0,0 +1,190 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/transpose_op.h"
+#include "tensorflow/core/kernels/transpose_op_functor.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+// inv = InvertPermutationOp(T<int32> p) takes a permutation of
+// integers 0, 1, ..., n - 1 and returns the inverted
+// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
+//
+// REQUIRES: input is a vector of int32.
+// REQUIRES: input is a permutation of 0, 1, ..., n-1.
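+//
+// For example, p = [2, 0, 1] yields inv = [1, 2, 0].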
+
+class InvertPermutationOp : public OpKernel {
+ public:
+ explicit InvertPermutationOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("invert_permutation expects a 1D vector."));
+ auto Tin = input.vec<int32>();
+ const int N = Tin.size();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ auto Tout = output->vec<int32>();
+ std::fill_n(Tout.data(), N, -1);
+ for (int i = 0; i < N; ++i) {
+ const int32 d = Tin(i);
+ OP_REQUIRES(context, 0 <= d && d < N,
+ errors::InvalidArgument(d, " is not between 0 and ", N));
+ OP_REQUIRES(context, Tout(d) == -1,
+ errors::InvalidArgument(d, " is duplicated in the input."));
+ Tout(d) = i;
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("InvertPermutation").Device(DEVICE_CPU),
+ InvertPermutationOp);
+
+// output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
+// of type T and rank N, and a permutation of 0, 1, ..., N-1. It
+// shuffles the dimensions of the input tensor according to the permutation.
+//
+// Specifically, the returned tensor output meets the following condition:
+// 1) output.dims() == input.dims();
+// 2) output.dim_size(i) == input.dim_size(perm[i]);
+// 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) ==
+// input.tensor<T, N>(j_0, j_1, ..., j_N-1),
+// where i_s == j_{perm[s]}
+//
+// REQUIRES: perm is a vector of int32.
+// REQUIRES: input.dims() == perm.size().
+// REQUIRES: perm is a permutation.
+
+template <typename Device, typename T>
+TransposeOp<Device, T>::TransposeOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+template <typename Device, typename T>
+void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
+ const Tensor& input = context->input(0);
+ const Tensor& perm = context->input(1);
+ // Preliminary validation of sizes.
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(perm.shape()),
+ errors::InvalidArgument("perm must be a vector, not ",
+ perm.shape().DebugString()));
+ auto Vperm = perm.vec<int32>();
+ const int dims = input.dims();
+ static const int kMinDims = 1;
+ static const int kMaxDims = 8;
+ OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
+ errors::Unimplemented("Transposing a tensor of rank ", dims,
+ " is not implemented."));
+ OP_REQUIRES(context, dims == Vperm.size(),
+ errors::InvalidArgument(
+ "transpose expects a vector of size ", input.dims(),
+ ". But input(1) is a vector of size ", Vperm.size()));
+ gtl::ArraySlice<int32> permutation(
+ reinterpret_cast<const int32*>(Vperm.data()), dims);
+ TensorShape shape;
+
+ // Check whether permutation is a permutation of integers of [0 .. dims).
+ gtl::InlinedVector<bool, 8> bits(dims);
+ for (const int32 d : permutation) {
+ OP_REQUIRES(
+ context, 0 <= d && d < dims,
+ errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
+ bits[d] = true;
+ shape.AddDim(input.dim_size(d));
+ }
+ for (int i = 0; i < dims; ++i) {
+ OP_REQUIRES(context, bits[i], errors::InvalidArgument(
+ i, " is missing from {",
+ str_util::Join(permutation, ","), "}."));
+ }
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
+ switch (dims) {
+#define EXPAND_DIM(N) \
+ case N: { \
+ functor::TransposeFunctor<Device, T, N> func; \
+ func(context->eigen_device<Device>(), output->tensor<T, N>(), \
+ input.tensor<T, N>(), permutation.data()); \
+ break; \
+ }
+ EXPAND_DIM(1);
+ EXPAND_DIM(2);
+ EXPAND_DIM(3);
+ EXPAND_DIM(4);
+ EXPAND_DIM(5);
+ EXPAND_DIM(6);
+ EXPAND_DIM(7);
+ EXPAND_DIM(8);
+ default:
+ LOG(FATAL) << "Unexpected dims: " << dims;
+ }
+#undef EXPAND_DIM
+}
+
+namespace functor {
+
+template <typename Device, typename T, int NDIMS>
+void TransposeMaybeInline(const Device& d,
+ typename TTypes<T, NDIMS>::Tensor out,
+ typename TTypes<T, NDIMS>::ConstTensor in,
+ const int* perm) {
+ // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
+ Eigen::array<int, NDIMS> p;
+ for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
+  // For small outputs (< 128KiB) do the shuffle inline on the calling
+  // thread rather than dispatching it to the thread-pool device.
+  if (out.size() * sizeof(T) < 131072) {
+ out = in.shuffle(p);
+ } else {
+ out.device(d) = in.shuffle(p);
+ }
+}
+
+template <typename T, int NDIMS>
+struct TransposeFunctor<CPUDevice, T, NDIMS> {
+ void operator()(const CPUDevice& d, typename TTypes<T, NDIMS>::Tensor out,
+ typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
+ TransposeMaybeInline<CPUDevice, T, NDIMS>(d, out, in, perm);
+ }
+};
+
+} // namespace functor
+
+#define REGISTER(D, T) \
+ template class TransposeOp<D##Device, T>; \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("perm"), \
+ TransposeOp<D##Device, T>)
+REGISTER(CPU, float);
+REGISTER(CPU, double);
+REGISTER(CPU, complex64);
+REGISTER(CPU, uint8);
+REGISTER(CPU, int8);
+REGISTER(CPU, int16);
+REGISTER(CPU, int32);
+REGISTER(CPU, int64);
+REGISTER(CPU, string);
+#if GOOGLE_CUDA
+REGISTER(GPU, uint8);
+REGISTER(GPU, int8);
+REGISTER(GPU, int16);
+REGISTER(GPU, int32);
+REGISTER(GPU, int64);
+REGISTER(GPU, float);
+REGISTER(GPU, double);
+#endif
+#undef REGISTER
+} // namespace tensorflow
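The two kernels above reduce to a small amount of index arithmetic. Below is a minimal standalone sketch (not part of the patch; the function names are illustrative only) of the inverse-permutation check and the permuted-shape rule output.dim_size(i) == input.dim_size(perm[i]):

#include <cstdio>
#include <vector>

// Returns the inverse permutation, or an empty vector if `perm` is not a
// valid permutation of 0..n-1 (out-of-range or duplicated entries).
std::vector<int> InvertPermutation(const std::vector<int>& perm) {
  const int n = static_cast<int>(perm.size());
  std::vector<int> inv(n, -1);
  for (int i = 0; i < n; ++i) {
    const int d = perm[i];
    if (d < 0 || d >= n || inv[d] != -1) return {};
    inv[d] = i;  // inv[perm[i]] == i
  }
  return inv;
}

// output.dim_size(i) == input.dim_size(perm[i]).
std::vector<int> PermutedShape(const std::vector<int>& dims,
                               const std::vector<int>& perm) {
  std::vector<int> out;
  for (int d : perm) out.push_back(dims[d]);
  return out;
}

int main() {
  const std::vector<int> perm = {2, 0, 1};
  const std::vector<int> inv = InvertPermutation(perm);        // {1, 2, 0}
  const std::vector<int> shape = PermutedShape({4, 5, 6}, perm);  // {6, 4, 5}
  std::printf("%d %d %d\n", shape[0], shape[1], shape[2]);
  return 0;
}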
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
new file mode 100644
index 0000000000..f7a5be5c2b
--- /dev/null
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
+#define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+template <typename Device, typename T>
+class TransposeOp : public OpKernel {
+ public:
+ explicit TransposeOp(OpKernelConstruction* context);
+ void Compute(OpKernelContext* context) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
diff --git a/tensorflow/core/kernels/transpose_op_functor.h b/tensorflow/core/kernels/transpose_op_functor.h
new file mode 100644
index 0000000000..8cbd1cbb29
--- /dev/null
+++ b/tensorflow/core/kernels/transpose_op_functor.h
@@ -0,0 +1,28 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, int NDIMS>
+void Transpose(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
+ typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
+ // perm[] is a permutation of 0, 1, ..., NDIMS-1. perm[] is on CPU.
+ Eigen::array<int, NDIMS> p;
+ for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
+ out.device(d) = in.shuffle(p);
+}
+
+template <typename Device, typename T, int NDIMS>
+struct TransposeFunctor {
+ void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor out,
+ typename TTypes<T, NDIMS>::ConstTensor in, const int* perm);
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_FUNCTOR_H_
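For reference, a standalone sketch (not part of the patch) of the Eigen shuffle primitive the functor relies on, assuming only a stock Eigen checkout with the unsupported Tensor module on the include path:

#include <cstdio>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2, Eigen::RowMajor> in(2, 3);
  in.setValues({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});

  // perm = {1, 0} swaps the two dimensions, i.e. a plain matrix transpose.
  Eigen::array<int, 2> perm{{1, 0}};
  Eigen::Tensor<float, 2, Eigen::RowMajor> out = in.shuffle(perm);

  // out has shape (3, 2) and out(j, i) == in(i, j), which is exactly the
  // contract documented for TransposeOp in transpose_op.cc.
  std::printf("%g %g\n", out(0, 1), in(1, 0));  // both print 4
  return 0;
}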
diff --git a/tensorflow/core/kernels/transpose_op_gpu.cu.cc b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
new file mode 100644
index 0000000000..8c04a6544e
--- /dev/null
+++ b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
@@ -0,0 +1,43 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/kernels/transpose_op_functor.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename T, int NDIMS>
+struct TransposeFunctor<Eigen::GpuDevice, T, NDIMS> {
+ void operator()(const Eigen::GpuDevice& d,
+ typename TTypes<T, NDIMS>::Tensor out,
+ typename TTypes<T, NDIMS>::ConstTensor in, const int* perm) {
+ Transpose<Eigen::GpuDevice, T, NDIMS>(d, out, in, perm);
+ }
+};
+
+#define DEFINE(T, N) template struct TransposeFunctor<Eigen::GpuDevice, T, N>;
+#define DEFINE_DIM(T) \
+ DEFINE(T, 1); \
+ DEFINE(T, 2); \
+ DEFINE(T, 3); \
+ DEFINE(T, 4); \
+ DEFINE(T, 5); \
+ DEFINE(T, 6); \
+ DEFINE(T, 7); \
+ DEFINE(T, 8);
+DEFINE_DIM(uint8);
+DEFINE_DIM(int8);
+DEFINE_DIM(int16);
+DEFINE_DIM(int32);
+DEFINE_DIM(int64);
+DEFINE_DIM(float);
+DEFINE_DIM(double);
+#undef DEFINE_DIM
+#undef DEFINE
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
new file mode 100644
index 0000000000..61f4a54583
--- /dev/null
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -0,0 +1,61 @@
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename T>
+class UniqueOp : public OpKernel {
+ public:
+ explicit UniqueOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt, DT_INT32}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
+ errors::InvalidArgument("unique expects a 1D vector."));
+ auto Tin = input.vec<T>();
+ const int N = Tin.size();
+
+ Tensor* idx = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(1, input.shape(), &idx));
+ auto idx_vec = idx->template vec<int32>();
+
+ std::unordered_map<T, int32> uniq;
+ uniq.reserve(2 * N);
+ for (int i = 0, j = 0; i < N; ++i) {
+ auto it = uniq.insert(std::make_pair(Tin(i), j));
+ idx_vec(i) = it.first->second;
+ if (it.second) {
+ ++j;
+ }
+ }
+ int32 uniq_size = uniq.size();
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({uniq_size}), &output));
+ auto output_vec = output->template vec<T>();
+
+ for (auto it : uniq) {
+ output_vec(it.second) = it.first;
+ }
+ }
+};
+
+#define REGISTER_UNIQUE(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Unique").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ UniqueOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE);
+#undef REGISTER_UNIQUE
+} // namespace tensorflow
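A standalone sketch (not part of the patch) of the single-pass algorithm UniqueOp uses: a hash map records the output slot of each value's first occurrence, which yields the unique values and the per-element indices in one loop.

#include <cstdio>
#include <unordered_map>
#include <vector>

void Unique(const std::vector<int>& in, std::vector<int>* values,
            std::vector<int>* indices) {
  values->clear();
  indices->resize(in.size());
  std::unordered_map<int, int> first_index;  // value -> slot in *values
  for (size_t i = 0; i < in.size(); ++i) {
    auto inserted =
        first_index.insert({in[i], static_cast<int>(values->size())});
    if (inserted.second) values->push_back(in[i]);  // first occurrence
    (*indices)[i] = inserted.first->second;
  }
}

int main() {
  std::vector<int> values, indices;
  Unique({1, 1, 2, 4, 4, 4, 7, 8, 8}, &values, &indices);
  // values  == {1, 2, 4, 7, 8}
  // indices == {0, 0, 1, 2, 2, 2, 3, 4, 4}
  for (int v : values) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}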
diff --git a/tensorflow/core/kernels/unique_op_test.cc b/tensorflow/core/kernels/unique_op_test.cc
new file mode 100644
index 0000000000..658f2282cf
--- /dev/null
+++ b/tensorflow/core/kernels/unique_op_test.cc
@@ -0,0 +1,51 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+
+namespace {
+
+static void BM_Unique(int iters, int dim) {
+ testing::StopTiming();
+ RequireDefaultOps();
+ Graph* g = new Graph(OpRegistry::Global());
+
+ Tensor input(DT_INT32, TensorShape({dim}));
+ input.flat<int32>().setRandom();
+
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Unique")
+ .Input(test::graph::Constant(g, input))
+ .Attr("T", DT_INT32)
+ .Finalize(g, &node));
+
+ testing::BytesProcessed(static_cast<int64>(iters) * dim * sizeof(int32));
+ testing::UseRealTime();
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_Unique)
+ ->Arg(32)
+ ->Arg(256)
+ ->Arg(1024)
+ ->Arg(4 * 1024)
+ ->Arg(16 * 1024)
+ ->Arg(64 * 1024)
+ ->Arg(256 * 1024);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc
new file mode 100644
index 0000000000..36cfb2c8e5
--- /dev/null
+++ b/tensorflow/core/kernels/unpack_op.cc
@@ -0,0 +1,96 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/split_op.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class UnpackOp : public OpKernel {
+ public:
+ explicit UnpackOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* context) override {
+ const int32 num = num_outputs();
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ OP_REQUIRES(
+ context, input_shape.dims() > 0 && input_shape.dim_size(0) == num,
+ errors::InvalidArgument("Input shape must start with ", num, ", got ",
+ input_shape.ShortDebugString()));
+
+ auto output_shape = input_shape;
+ output_shape.RemoveDim(0);
+ const int32 output_size = output_shape.num_elements();
+
+ // Special case: Aligned, so we can share the underlying buffer.
+ //
+    // Apply this optimization conservatively: if the input is aligned, the
+    // resulting tensors must be aligned too. It's conservative because even
+    // if the immediate consumers of the resulting tensors do not use Eigen
+    // for computation, it is perfectly fine to skip the copy.
+ if (output_size == 0 || IsInnerDimsSizeAligned<T>(input_shape)) {
+ for (int i = 0; i < num; ++i) {
+ Tensor output;
+ CHECK(output.CopyFrom(input.Slice(i, i + 1), output_shape));
+ context->set_output(i, output);
+ }
+ return;
+ }
+
+ // Except for shape, unpack is a special case of split, so we reuse the
+ // same computational kernels.
+ auto input_reshaped = input.shaped<T, 3>({1, num, output_size});
+
+ for (int i = 0; i < num; ++i) {
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &output));
+ auto output_shaped = output->shaped<T, 3>({1, 1, output_size});
+
+ Eigen::DSizes<ptrdiff_t, 3> indices{0, i, 0};
+ Eigen::DSizes<ptrdiff_t, 3> sizes{1, 1, output_size};
+ functor::Split<Device, T>()(context->eigen_device<Device>(),
+ output_shaped, input_reshaped, indices,
+ sizes);
+ }
+ }
+};
+
+#define REGISTER_UNPACK(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Unpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ UnpackOp<CPUDevice, type>)
+
+TF_CALL_ALL_TYPES(REGISTER_UNPACK);
+
+#undef REGISTER_UNPACK
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Unpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ UnpackOp<GPUDevice, type>)
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+#undef REGISTER_GPU
+
+#endif // GOOGLE_CUDA
+
+} // end namespace tensorflow
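The aliasing fast path above boils down to the observation that, for a row-major [num, k] tensor, output i is the contiguous range [i*k, (i+1)*k). A standalone sketch (not part of the patch; View is a hypothetical helper type) of handing out views instead of copies:

#include <cstdio>
#include <vector>

struct View {
  const float* data;
  size_t size;
};

// Splits a row-major [num, k] buffer into num views without copying.
std::vector<View> Unpack(const std::vector<float>& input, size_t num) {
  const size_t k = input.size() / num;  // elements per output
  std::vector<View> outputs;
  for (size_t i = 0; i < num; ++i) {
    outputs.push_back({input.data() + i * k, k});  // just an offset, no copy
  }
  return outputs;
}

int main() {
  const std::vector<float> input = {1, 2, 3, 4, 5, 6};  // shape [3, 2]
  for (const View& v : Unpack(input, 3)) {
    std::printf("%g %g\n", v.data[0], v.data[1]);
  }
  return 0;
}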
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
new file mode 100644
index 0000000000..2f1dbc68c0
--- /dev/null
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -0,0 +1,37 @@
+#define EIGEN_USE_THREADS
+#include "tensorflow/core/kernels/variable_ops.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("Variable").Device(DEVICE_CPU), VariableOp);
+REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
+ TemporaryVariableOp);
+REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
+ DestroyTemporaryVariableOp);
+
+#if GOOGLE_CUDA
+// Only register 'Variable' on GPU for the subset of types also supported by
+// 'Assign' (see dense_update_ops.cc).
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Variable").Device(DEVICE_GPU).TypeConstraint<type>("dtype"), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("TemporaryVariable") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype"), \
+ TemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T"), \
+ DestroyTemporaryVariableOp);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
new file mode 100644
index 0000000000..77d2da0ad4
--- /dev/null
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -0,0 +1,146 @@
+#ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+#define TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class VariableOp : public OpKernel {
+ public:
+ explicit VariableOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
+ dtype_ = RemoveRefType(context->output_type(0));
+ }
+
+ ~VariableOp() override {
+ if (var_) var_->Unref();
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ mutex_lock l(init_mu_);
+ if (var_ == nullptr) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
+ true /* use name() */));
+ auto creator = [this](Var** var) {
+ *var = new Var(dtype_);
+ (*var)->tensor()->set_shape(shape_);
+ return Status::OK();
+ };
+ OP_REQUIRES_OK(ctx,
+ cinfo_.resource_manager()->LookupOrCreate<Var>(
+ cinfo_.container(), cinfo_.name(), &var_, creator));
+ }
+ // Output a reference to our tensor, so it may be updated.
+ //
+ // As long as *this is alive, the ref we return here is valid
+ // because *this owns a ref on var_.
+ ctx->set_output_ref(0, var_->mu(), var_->tensor());
+ }
+
+ private:
+ class Var : public ResourceBase {
+ public:
+ explicit Var(DataType dtype) : tensor_(dtype) {}
+ mutex* mu() { return &mu_; }
+ Tensor* tensor() { return &tensor_; }
+
+ string DebugString() override {
+ return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
+ tensor_.shape().ShortDebugString());
+ }
+
+ private:
+ mutex mu_;
+ Tensor tensor_;
+
+ ~Var() override {}
+ TF_DISALLOW_COPY_AND_ASSIGN(Var);
+ };
+
+ DataType dtype_;
+ TensorShape shape_;
+
+ mutex init_mu_;
+ ContainerInfo cinfo_ GUARDED_BY(init_mu_);
+ Var* var_ GUARDED_BY(init_mu_) = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(VariableOp);
+};
+
+class TemporaryVariableOp : public OpKernel {
+ public:
+ explicit TemporaryVariableOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
+ // Variable name defaults to op name if not specified explicitly.
+ if (var_name_ == "") var_name_ = name();
+ }
+
+ void Compute(OpKernelContext* context) override {
+ Status s;
+ ResourceMgr* rm = context->step_resource_manager();
+ OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
+ auto* tmp_var = new TmpVar;
+ OP_REQUIRES(context, tmp_var,
+ errors::ResourceExhausted("Could not allocate TmpVar."));
+ tmp_var->name = var_name_;
+ s = context->allocate_temp(dtype_, shape_, &tmp_var->val);
+ if (!s.ok()) tmp_var->Unref();
+ OP_REQUIRES_OK(context, s);
+ OP_REQUIRES_OK(context, rm->Create("tmp_var", var_name_, tmp_var));
+ context->set_output_ref(0, &tmp_var->mu, &tmp_var->val);
+ }
+
+ private:
+ // Refcounted temporary variable resource.
+ friend class DestroyTemporaryVariableOp;
+ struct TmpVar : public ResourceBase {
+ mutex mu;
+ Tensor val;
+ string name;
+ string DebugString() override { return name; }
+ ~TmpVar() override { VLOG(3) << "TmpVar " << name << " deleted"; }
+ };
+
+ TensorShape shape_;
+ DataType dtype_;
+ string var_name_;
+};
+
+class DestroyTemporaryVariableOp : public OpKernel {
+ public:
+ explicit DestroyTemporaryVariableOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES(context, IsRefType(context->input_type(0)),
+                errors::InvalidArgument("lhs input needs to be a ref type"));
+ OP_REQUIRES_OK(context, context->GetAttr("var_name", &var_name_));
+ OP_REQUIRES(context, var_name_ != "",
+ errors::InvalidArgument("Missing var_name attribute"));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // NOTE(pbar): All other mutators of the Tensor Ref *must* have completed
+ // their execution before this DestroyTemporaryVariable op executes.
+ // This is typically achieved using control dependencies.
+ CHECK(IsRefType(context->input_dtype(0)));
+ Tensor tmpvar = context->mutable_input(0, false);
+ context->set_output(0, tmpvar);
+ ResourceMgr* rm = context->step_resource_manager();
+ OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager."));
+ OP_REQUIRES_OK(
+ context, rm->Delete<TemporaryVariableOp::TmpVar>("tmp_var", var_name_));
+ }
+
+ private:
+ string var_name_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_
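A standalone sketch (not part of the patch; all names hypothetical) of the create-once-under-a-mutex pattern VariableOp::Compute follows: the holder keeps its own reference, so the pointer it hands out stays valid for the holder's lifetime.

#include <mutex>

struct Resource {
  int refcount = 1;  // the holder's reference; not thread-safe, for brevity
  void Ref() { ++refcount; }
  void Unref() {
    if (--refcount == 0) delete this;
  }
};

class LazyHolder {
 public:
  ~LazyHolder() {
    if (resource_ != nullptr) resource_->Unref();
  }

  // Creates the resource on first use; later calls return the same pointer.
  Resource* GetOrCreate() {
    std::lock_guard<std::mutex> l(mu_);
    if (resource_ == nullptr) resource_ = new Resource;
    return resource_;  // borrowed: valid as long as the holder is alive
  }

 private:
  std::mutex mu_;
  Resource* resource_ = nullptr;
};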
diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc
new file mode 100644
index 0000000000..9db0943ea7
--- /dev/null
+++ b/tensorflow/core/kernels/where_op.cc
@@ -0,0 +1,74 @@
+// See docs in ../ops/array_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/where_op.h"
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device>
+class WhereOp : public OpKernel {
+ public:
+ explicit WhereOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+
+ const int input_dims = input.dims();
+ Tensor num_true;
+ OP_REQUIRES_OK(
+ context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true));
+ auto num_true_t = num_true.scalar<int64>();
+
+ functor::NumTrue<Device>::Compute(context->eigen_device<Device>(),
+ input.flat<bool>(), num_true_t);
+ TensorShape output_shape({num_true_t(), input_dims});
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+
+#define HANDLE_DIM(NDIM) \
+ case NDIM: \
+ functor::Where<Device, NDIM>::Compute(context->eigen_device<Device>(), \
+ input.tensor<bool, NDIM>(), \
+ output->matrix<int64>()); \
+ break;
+
+ switch (input_dims) {
+ HANDLE_DIM(1);
+ HANDLE_DIM(2);
+ HANDLE_DIM(3);
+ HANDLE_DIM(4);
+ HANDLE_DIM(5);
+
+ default:
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument(
+ "WhereOp : Unhandled input dimensions: ", input_dims));
+ }
+#undef HANDLE_DIM
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(WhereOp);
+};
+
+#define REGISTER_WHERE() \
+ REGISTER_KERNEL_BUILDER(Name("Where").Device(DEVICE_CPU), WhereOp<CPUDevice>);
+
+REGISTER_WHERE();
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h
new file mode 100644
index 0000000000..c7b835d02f
--- /dev/null
+++ b/tensorflow/core/kernels/where_op.h
@@ -0,0 +1,65 @@
+#ifndef TENSORFLOW_KERNELS_WHERE_OP_H_
+#define TENSORFLOW_KERNELS_WHERE_OP_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device>
+struct NumTrue {
+ EIGEN_ALWAYS_INLINE static void Compute(
+ const Device& d, typename TTypes<bool>::ConstFlat input,
+ TTypes<int64>::Scalar num_true) {
+ num_true.device(d) = input.template cast<int64>().sum();
+ }
+};
+
+template <typename Device, int NDIM>
+struct Where {
+ EIGEN_ALWAYS_INLINE static void Compute(
+ const Device& d, typename TTypes<bool, NDIM>::ConstTensor input,
+ typename TTypes<int64>::Matrix output) {
+ Eigen::DenseIndex true_n = 0;
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> dims = input.dimensions();
+ Eigen::DSizes<Eigen::DenseIndex, NDIM> strides;
+
+ // Calculate strides for RowMajor order.
+ EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
+ static_cast<int>(Eigen::RowMajor)),
+ INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);
+
+ strides[NDIM - 1] = 1;
+ for (int i = NDIM - 2; i >= 0; --i) {
+ strides[i] = strides[i + 1] * dims[i + 1];
+ }
+
+    // Note: no bounds checking is done on true_n. It is assumed that the
+    // output was correctly sized using the count produced by NumTrue::Compute.
+ for (Eigen::DenseIndex n = 0; n < input.size(); ++n) {
+ if (input.data()[n]) {
+ WriteIndexRowMajor(output, strides, true_n, n);
+ ++true_n;
+ }
+ }
+ }
+
+ EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
+ typename TTypes<int64>::Matrix output,
+ const Eigen::DSizes<Eigen::DenseIndex, NDIM>& strides,
+ Eigen::DenseIndex true_n, Eigen::DenseIndex index) {
+ for (int i = 0; i < NDIM; ++i) {
+ output(true_n, i) = index / strides[i];
+ index %= strides[i];
+ }
+ }
+};
+
+} // namespace functor
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_WHERE_OP_H_
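A standalone sketch (not part of the patch) of the index arithmetic in WriteIndexRowMajor above: row-major strides turn a flat offset into per-dimension coordinates by repeated division and remainder.

#include <cstdio>

int main() {
  const int kNDim = 3;
  const long long dims[kNDim] = {2, 3, 4};

  // Row-major strides: the last dimension is contiguous.
  long long strides[kNDim];
  strides[kNDim - 1] = 1;
  for (int i = kNDim - 2; i >= 0; --i) strides[i] = strides[i + 1] * dims[i + 1];

  long long index = 17;  // flat offset into a [2, 3, 4] tensor
  long long coords[kNDim];
  for (int i = 0; i < kNDim; ++i) {
    coords[i] = index / strides[i];
    index %= strides[i];
  }
  // 17 = 1*12 + 1*4 + 1, so coords == {1, 1, 1}.
  std::printf("%lld %lld %lld\n", coords[0], coords[1], coords[2]);
  return 0;
}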
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
new file mode 100644
index 0000000000..b940163ec9
--- /dev/null
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -0,0 +1,108 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+static Status ReadEntireFile(Env* env, const string& filename,
+ string* contents) {
+ uint64 file_size = 0;
+ TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
+ contents->resize(file_size);
+ RandomAccessFile* file;
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
+ std::unique_ptr<RandomAccessFile> make_sure_file_gets_deleted(file);
+ StringPiece data;
+ TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(*contents)[0]));
+ if (data.size() != file_size) {
+ return errors::DataLoss("Truncated read of '", filename, "' expected ",
+ file_size, " got ", data.size());
+ }
+ if (data.data() != &(*contents)[0]) {
+ memmove(&(*contents)[0], data.data(), data.size());
+ }
+ return Status::OK();
+}
+
+class WholeFileReader : public ReaderBase {
+ public:
+ WholeFileReader(Env* env, const string& node_name)
+ : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")),
+ env_(env) {}
+
+ Status ReadLocked(string* key, string* value, bool* produced,
+ bool* at_end) override {
+ *key = current_work();
+ TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value));
+ *produced = true;
+ *at_end = true;
+ return Status::OK();
+ }
+
+ // Stores state in a ReaderBaseState proto, since WholeFileReader has
+ // no additional state beyond ReaderBase.
+ Status SerializeStateLocked(string* state) override {
+ ReaderBaseState base_state;
+ SaveBaseState(&base_state);
+ base_state.SerializeToString(state);
+ return Status::OK();
+ }
+
+ Status RestoreStateLocked(const string& state) override {
+ ReaderBaseState base_state;
+ if (!ParseProtoUnlimited(&base_state, state)) {
+ return errors::InvalidArgument("Could not parse state for ", name(), ": ",
+ str_util::CEscape(state));
+ }
+ TF_RETURN_IF_ERROR(RestoreBaseState(base_state));
+ return Status::OK();
+ }
+
+ private:
+ Env* env_;
+};
+
+class WholeFileReaderOp : public ReaderOpKernel {
+ public:
+ explicit WholeFileReaderOp(OpKernelConstruction* context)
+ : ReaderOpKernel(context) {
+ Env* env = context->env();
+ SetReaderFactory(
+ [this, env]() { return new WholeFileReader(env, name()); });
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("WholeFileReader").Device(DEVICE_CPU),
+ WholeFileReaderOp);
+
+class ReadFileOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input;
+ OP_REQUIRES_OK(context, context->input("filename", &input));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(input->shape()),
+ errors::InvalidArgument(
+ "Input filename tensor must be scalar, but had shape: ",
+ input->shape().DebugString()));
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output("contents",
+ TensorShape({}), &output));
+ OP_REQUIRES_OK(context,
+ ReadEntireFile(context->env(), input->scalar<string>()(),
+ &output->scalar<string>()()));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
new file mode 100644
index 0000000000..ff54d157af
--- /dev/null
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -0,0 +1,90 @@
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/kernels/xent_op.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class SoftmaxXentWithLogitsOp : public OpKernel {
+ public:
+ explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& logits_in = context->input(0);
+ const Tensor& labels_in = context->input(1);
+ OP_REQUIRES(context, logits_in.IsSameSize(labels_in),
+ errors::InvalidArgument(
+ "logits and labels must be same size: logits_size=",
+ logits_in.shape().DebugString(), " labels_size=",
+ labels_in.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+    // Since we already checked that both inputs have the same shape, there is
+    // no need to verify that "labels" is a matrix too.
+
+ // loss is 1-D (one per example), and size is batch_size.
+
+ Tensor scratch;
+ OP_REQUIRES_OK(
+ context, context->allocate_temp(DataTypeToEnum<T>::value,
+ TensorShape({logits_in.dim_size(0), 1}),
+ &scratch));
+
+ Tensor* loss_out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+ Tensor* back_out = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(1, logits_in.shape(), &back_out));
+
+ functor::XentFunctor<Device, T> functor;
+ functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
+ labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ }
+};
+
+// Partial specialization for a CPUDevice, that uses the Eigen implementation
+// from XentEigenImpl.
+namespace functor {
+template <typename T>
+struct XentFunctor<CPUDevice, T> {
+ void operator()(const CPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::ConstMatrix labels,
+ typename TTypes<T>::Matrix scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop) {
+ XentEigenImpl<CPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ backprop);
+ }
+};
+} // namespace functor
+
+REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ SoftmaxXentWithLogitsOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ SoftmaxXentWithLogitsOp<CPUDevice, double>);
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T"),
+ SoftmaxXentWithLogitsOp<GPUDevice, float>);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h
new file mode 100644
index 0000000000..edb7d817c8
--- /dev/null
+++ b/tensorflow/core/kernels/xent_op.h
@@ -0,0 +1,102 @@
+#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
+#define TENSORFLOW_KERNELS_XENT_OP_H_
+// Functor definition for XentOp, must be compilable by nvcc.
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+// Functor used by XentOp to do the computations.
+template <typename Device, typename T>
+struct XentFunctor {
+ // Computes Cross Entropy loss and backprop.
+ //
+ // logits: batch_size, num_classes.
+ // labels: batch_size, num_classes.
+ // scratch: temporary tensor, dims: batch_size, 1
+ // loss: output tensor for the loss, dims: batch_size.
+ // backprop: output tensor for the backprop, dims: batch_size, num_classes.
+ void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::ConstMatrix labels,
+ typename TTypes<T>::Matrix scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop);
+};
+
+// Eigen code implementing XentFunctor::operator().
+// This code works for both CPU and GPU and is used by the functor
+// specializations for both device types.
+template <typename Device, typename T>
+struct XentEigenImpl {
+ static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::ConstMatrix labels,
+ typename TTypes<T>::Matrix scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop) {
+ // NOTE(mdevin): This duplicates some of the computations in softmax_op
+ // because we need the intermediate (logits -max(logits)) values to
+ // avoid a log(exp()) in the computation of the loss.
+
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const int batch_size = logits.dimension(kBatchDim);
+ const int num_classes = logits.dimension(kClassDim);
+
+// These arrays are used to reduce along the class dimension, and broadcast
+// the resulting value to all classes.
+#if !defined(EIGEN_HAS_INDEX_LIST)
+ Eigen::array<int, 1> along_class;
+ along_class[0] = kClassDim;
+ Eigen::array<int, 1> batch_only;
+ batch_only[0] = batch_size;
+ Eigen::array<int, 2> batch_by_one;
+ batch_by_one[0] = batch_size;
+ batch_by_one[1] = 1;
+ Eigen::array<int, 2> one_by_class;
+ one_by_class[0] = 1;
+ one_by_class[1] = num_classes;
+#else
+ Eigen::IndexList<Eigen::type2index<kClassDim> > along_class;
+ Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one;
+ batch_by_one.set(0, batch_size);
+ Eigen::IndexList<int> batch_only;
+ batch_only.set(0, batch_size);
+ Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
+ one_by_class.set(1, num_classes);
+#endif
+
+ // max_logits along classes.
+ scratch.reshape(batch_only).device(d) = logits.maximum(along_class);
+
+ // logits - max_logits.
+ backprop.device(d) = logits - scratch.broadcast(one_by_class);
+
+ // sum(exp(logits - max_logits)) along classes.
+ scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class);
+
+    // NOTE(keveman): Eigen on GPU dispatches to an optimized implementation
+    // for an expression of the form lhs = rhs.sum().
+    // lhs = -rhs.sum() doesn't match that pattern, so we fold the negation
+    // into the summand before calling sum().
+ // sum(-labels *
+ // ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+ // along classes
+ loss.device(d) =
+ (labels * (scratch.log().eval().broadcast(one_by_class) - backprop))
+ .eval()
+ .sum(along_class);
+
+ // backprop: prob - labels, where
+ // prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+ backprop.device(d) =
+ (backprop.exp() / scratch.broadcast(one_by_class)) - labels;
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_XENT_OP_H_
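A standalone sketch (not part of the patch) of the same numerically stable computation for a single example, written without Eigen: subtract the max logit, form log-sum-exp once, and reuse the shifted logits for both the loss and the backprop.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// loss = sum_j labels[j] * (log(sum_k exp(logits[k] - max)) - (logits[j] - max))
// backprop[j] = softmax(logits)[j] - labels[j]
void XentRow(const std::vector<float>& logits, const std::vector<float>& labels,
             float* loss, std::vector<float>* backprop) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum_exp = 0.f;
  for (float l : logits) sum_exp += std::exp(l - max_logit);
  const float log_sum_exp = std::log(sum_exp);

  *loss = 0.f;
  backprop->resize(logits.size());
  for (size_t j = 0; j < logits.size(); ++j) {
    const float shifted = logits[j] - max_logit;
    *loss += labels[j] * (log_sum_exp - shifted);
    (*backprop)[j] = std::exp(shifted) / sum_exp - labels[j];
  }
}

int main() {
  float loss;
  std::vector<float> backprop;
  XentRow({1.f, 2.f, 3.f}, {0.f, 0.f, 1.f}, &loss, &backprop);
  std::printf("loss=%f\n", loss);  // ~0.4076, i.e. -log(softmax[2])
  return 0;
}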
diff --git a/tensorflow/core/kernels/xent_op_gpu.cu.cc b/tensorflow/core/kernels/xent_op_gpu.cu.cc
new file mode 100644
index 0000000000..eec6a84281
--- /dev/null
+++ b/tensorflow/core/kernels/xent_op_gpu.cu.cc
@@ -0,0 +1,35 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/xent_op.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization for a GPUDevice, that uses the Eigen implementation
+// from XentEigenImpl.
+namespace functor {
+template <typename T>
+struct XentFunctor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::ConstMatrix logits,
+ typename TTypes<T>::ConstMatrix labels,
+ typename TTypes<T>::Matrix scratch,
+ typename TTypes<T>::Vec loss,
+ typename TTypes<T>::Matrix backprop) {
+ XentEigenImpl<GPUDevice, T>::Compute(d, logits, labels, scratch, loss,
+ backprop);
+ }
+};
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::XentFunctor<GPUDevice, float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/xent_op_test.cc b/tensorflow/core/kernels/xent_op_test.cc
new file mode 100644
index 0000000000..9aab1b09bf
--- /dev/null
+++ b/tensorflow/core/kernels/xent_op_test.cc
@@ -0,0 +1,46 @@
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/kernels/xent_op.h"
+
+namespace tensorflow {
+
+static Graph* Xent(int batch_size, int num_classes) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
+ logits.flat<float>().setRandom();
+ Tensor labels(DT_FLOAT, TensorShape({batch_size, num_classes}));
+ labels.flat<float>().setRandom();
+ test::graph::Binary(g, "SoftmaxCrossEntropyWithLogits",
+ test::graph::Constant(g, logits),
+ test::graph::Constant(g, labels));
+ return g;
+}
+
+#define BM_XentDev(BATCH, CLASS, DEVICE) \
+ static void BM_Xent##_##BATCH##_##CLASS##_##DEVICE(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * BATCH * CLASS); \
+ test::Benchmark(#DEVICE, Xent(BATCH, CLASS)).Run(iters); \
+ } \
+ BENCHMARK(BM_Xent##_##BATCH##_##CLASS##_##DEVICE);
+
+/// The representative tests for ptb_word on GPU
+BM_XentDev(16, 10000, gpu);
+BM_XentDev(16, 30000, gpu);
+BM_XentDev(16, 100000, gpu);
+
+BM_XentDev(32, 10000, gpu);
+BM_XentDev(32, 30000, gpu);
+BM_XentDev(32, 100000, gpu);
+
+BM_XentDev(64, 10000, gpu);
+BM_XentDev(64, 30000, gpu);
+BM_XentDev(64, 100000, gpu);
+
+/// Only the smaller tests for CPU. Otherwise, it's too slow
+BM_XentDev(16, 10000, cpu);
+BM_XentDev(32, 10000, cpu);
+BM_XentDev(64, 10000, cpu);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc
new file mode 100644
index 0000000000..ceb1001af0
--- /dev/null
+++ b/tensorflow/core/lib/core/arena.cc
@@ -0,0 +1,246 @@
+// This approach to arenas overcomes many of the limitations described
+// in the "Specialized allocators" section of
+// http://www.pdos.lcs.mit.edu/~dm/c++-new.html
+//
+// A somewhat similar approach to Gladiator, but for heap-detection, was
+// suggested by Ron van der Wal and Scott Meyers at
+// http://www.aristeia.com/BookErrata/M27Comments_frames.html
+
+#include "tensorflow/core/lib/core/arena.h"
+
+#include <assert.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+namespace core {
+
+static const int kPageSize = getpagesize();
+
+// ----------------------------------------------------------------------
+// Arena::Arena()
+// Arena::~Arena()
+// Destroying the arena automatically calls Reset()
+// ----------------------------------------------------------------------
+
+Arena::Arena(const size_t block_size)
+ : remaining_(0),
+ block_size_(block_size),
+ freestart_(NULL), // set for real in Reset()
+ blocks_alloced_(1),
+ overflow_blocks_(NULL) {
+ assert(block_size > kDefaultAlignment);
+
+ first_blocks_[0].mem = reinterpret_cast<char*>(malloc(block_size_));
+
+ first_blocks_[0].size = block_size_;
+
+ Reset();
+}
+
+Arena::~Arena() {
+ FreeBlocks();
+ assert(overflow_blocks_ == NULL); // FreeBlocks() should do that
+ // The first X blocks stay allocated always by default. Delete them now.
+ for (size_t i = 0; i < blocks_alloced_; ++i) free(first_blocks_[i].mem);
+}
+
+// Returns true iff it advances freestart_ to the first position
+// satisfying alignment without exhausting the current block.
+bool Arena::SatisfyAlignment(size_t alignment) {
+ const size_t overage = reinterpret_cast<size_t>(freestart_) & (alignment - 1);
+ if (overage > 0) {
+ const size_t waste = alignment - overage;
+ if (waste >= remaining_) {
+ return false;
+ }
+ freestart_ += waste;
+ remaining_ -= waste;
+ }
+ DCHECK_EQ(0, reinterpret_cast<size_t>(freestart_) & (alignment - 1));
+ return true;
+}
+
+// ----------------------------------------------------------------------
+// Arena::Reset()
+// Clears all the memory an arena is using.
+// ----------------------------------------------------------------------
+
+void Arena::Reset() {
+ FreeBlocks();
+ freestart_ = first_blocks_[0].mem;
+ remaining_ = first_blocks_[0].size;
+
+ // There is no guarantee the first block is properly aligned, so
+ // enforce that now.
+ CHECK(SatisfyAlignment(kDefaultAlignment));
+
+ freestart_when_empty_ = freestart_;
+}
+
+// ----------------------------------------------------------------------
+// Arena::MakeNewBlock()
+// Our sbrk() equivalent. We always make blocks of the same size
+// (though GetMemory() can also make a new block for really big
+//    data).
+// ----------------------------------------------------------------------
+
+void Arena::MakeNewBlock(const uint32 alignment) {
+ AllocatedBlock* block = AllocNewBlock(block_size_, alignment);
+ freestart_ = block->mem;
+ remaining_ = block->size;
+ CHECK(SatisfyAlignment(alignment));
+}
+
+// The following simple numeric routines also exist in util/math/mathutil.h
+// but we don't want to depend on that library.
+
+// Euclid's algorithm for the Greatest Common Divisor.
+static uint32 GCD(uint32 x, uint32 y) {
+ while (y != 0) {
+ uint32 r = x % y;
+ x = y;
+ y = r;
+ }
+ return x;
+}
+
+static uint32 LeastCommonMultiple(uint32 a, uint32 b) {
+ if (a > b) {
+ return (a / GCD(a, b)) * b;
+ } else if (a < b) {
+ return (b / GCD(b, a)) * a;
+ } else {
+ return a;
+ }
+}
+
+// -------------------------------------------------------------
+// Arena::AllocNewBlock()
+// Adds and returns an AllocatedBlock.
+// The returned AllocatedBlock* is valid until the next call
+// to AllocNewBlock or Reset. (i.e. anything that might
+// affect overflow_blocks_).
+// -------------------------------------------------------------
+
+Arena::AllocatedBlock* Arena::AllocNewBlock(const size_t block_size,
+ const uint32 alignment) {
+ AllocatedBlock* block;
+ // Find the next block.
+ if (blocks_alloced_ < TF_ARRAYSIZE(first_blocks_)) {
+ // Use one of the pre-allocated blocks
+ block = &first_blocks_[blocks_alloced_++];
+ } else { // oops, out of space, move to the vector
+ if (overflow_blocks_ == NULL)
+ overflow_blocks_ = new std::vector<AllocatedBlock>;
+ // Adds another block to the vector.
+ overflow_blocks_->resize(overflow_blocks_->size() + 1);
+ // block points to the last block of the vector.
+ block = &overflow_blocks_->back();
+ }
+
+ // NOTE(tucker): this utility is made slightly more complex by
+ // not disallowing the case where alignment > block_size.
+  // Could we disallow that case without breaking existing code?
+
+ // Must be a multiple of kDefaultAlignment, unless requested
+ // alignment is 1, in which case we don't care at all.
+ const uint32 adjusted_alignment =
+ (alignment > 1 ? LeastCommonMultiple(alignment, kDefaultAlignment) : 1);
+
+ CHECK_LE(adjusted_alignment, 1 << 20)
+ << "Alignment on boundaries greater than 1MB not supported.";
+
+ // If block_size > alignment we force block_size to be a multiple
+ // of alignment; if block_size < alignment we make no adjustment.
+ size_t adjusted_block_size = block_size;
+ if (adjusted_alignment > 1) {
+ if (adjusted_block_size > adjusted_alignment) {
+ const uint32 excess = adjusted_block_size % adjusted_alignment;
+ adjusted_block_size += (excess > 0 ? adjusted_alignment - excess : 0);
+ }
+ block->mem = reinterpret_cast<char*>(
+ port::aligned_malloc(adjusted_block_size, adjusted_alignment));
+ } else {
+ block->mem = reinterpret_cast<char*>(malloc(adjusted_block_size));
+ }
+ block->size = adjusted_block_size;
+ CHECK(NULL != block->mem) << "block_size=" << block_size
+ << " adjusted_block_size=" << adjusted_block_size
+ << " alignment=" << alignment
+ << " adjusted_alignment=" << adjusted_alignment;
+
+ return block;
+}
+
+// ----------------------------------------------------------------------
+// Arena::GetMemoryFallback()
+// We take memory out of our pool, aligned on the byte boundary
+// requested. If we don't have space in our current pool, we
+// allocate a new block (wasting the remaining space in the
+// current block) and give you that. If your memory needs are
+// too big for a single block, we make a special your-memory-only
+// allocation -- this is equivalent to not using the arena at all.
+// ----------------------------------------------------------------------
+
+void* Arena::GetMemoryFallback(const size_t size, const int alignment) {
+ if (0 == size) {
+ return NULL; // stl/stl_alloc.h says this is okay
+ }
+
+ // alignment must be a positive power of 2.
+ CHECK(alignment > 0 && 0 == (alignment & (alignment - 1)));
+
+ // If the object is more than a quarter of the block size, allocate
+ // it separately to avoid wasting too much space in leftover bytes.
+ if (block_size_ == 0 || size > block_size_ / 4) {
+ return AllocNewBlock(size, alignment)->mem;
+ }
+
+ // Enforce alignment on freestart_ then check for adequate space,
+ // which may require starting a new block.
+ if (!SatisfyAlignment(alignment) || size > remaining_) {
+ MakeNewBlock(alignment);
+ }
+ CHECK_LE(size, remaining_);
+
+ remaining_ -= size;
+ void* result = freestart_;
+ freestart_ += size;
+
+ return result;
+}
+
+// ----------------------------------------------------------------------
+// Arena::ReturnMemoryFallback()
+// Arena::FreeBlocks()
+// Unlike GetMemory(), which does actual work, ReturnMemory() is a
+// no-op: we don't "free" memory until Reset() is called. We do
+// update some stats, though. Note we do no checking that the
+// pointer you pass in was actually allocated by us, or that it
+// was allocated for the size you say, so be careful here!
+// FreeBlocks() does the work for Reset(), actually freeing all
+// memory allocated in one fell swoop.
+// ----------------------------------------------------------------------
+
+void Arena::FreeBlocks() {
+ for (size_t i = 1; i < blocks_alloced_; ++i) { // keep first block alloced
+ free(first_blocks_[i].mem);
+ first_blocks_[i].mem = NULL;
+ first_blocks_[i].size = 0;
+ }
+ blocks_alloced_ = 1;
+ if (overflow_blocks_ != NULL) {
+ std::vector<AllocatedBlock>::iterator it;
+ for (it = overflow_blocks_->begin(); it != overflow_blocks_->end(); ++it) {
+ free(it->mem);
+ }
+ delete overflow_blocks_; // These should be used very rarely
+ overflow_blocks_ = NULL;
+ }
+}
+
+} // namespace core
+} // namespace tensorflow
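A standalone sketch (not part of the patch) of the two numeric helpers AllocNewBlock depends on: the GCD/LCM used to merge the requested alignment with kDefaultAlignment, and the rounding of a block size up to a multiple of that alignment.

#include <cassert>
#include <cstddef>
#include <cstdint>

uint32_t GCD(uint32_t x, uint32_t y) {
  while (y != 0) {
    const uint32_t r = x % y;
    x = y;
    y = r;
  }
  return x;
}

uint32_t LCM(uint32_t a, uint32_t b) { return (a / GCD(a, b)) * b; }

// Rounds `size` up to a multiple of `alignment` (alignment > 0), mirroring
// the adjustment applied when block_size > alignment.
size_t RoundUp(size_t size, uint32_t alignment) {
  const size_t excess = size % alignment;
  return excess > 0 ? size + (alignment - excess) : size;
}

int main() {
  assert(LCM(8, 12) == 24);        // merged alignment of 8 and 12
  assert(RoundUp(1000, 64) == 1024);
  return 0;
}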
diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h
new file mode 100644
index 0000000000..59896803bb
--- /dev/null
+++ b/tensorflow/core/lib/core/arena.h
@@ -0,0 +1,90 @@
+// TODO(vrv): Switch this to an open-sourced version of Arena.
+
+#ifndef TENSORFLOW_LIB_CORE_ARENA_H_
+#define TENSORFLOW_LIB_CORE_ARENA_H_
+
+#include <assert.h>
+
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace core {
+
+// This class is "thread-compatible": different threads can access the
+// arena at the same time without locking, as long as they use only
+// const methods.
+class Arena {
+ public:
+ // Allocates a thread-compatible arena with the specified block size.
+ explicit Arena(const size_t block_size);
+ ~Arena();
+
+ char* Alloc(const size_t size) {
+ return reinterpret_cast<char*>(GetMemory(size, 1));
+ }
+
+ void Reset();
+
+// This should be the worst-case alignment for any type. This is
+// good for IA-32, SPARC version 7 (the last one I know), and
+// supposedly Alpha. i386 would be more time-efficient with a
+// default alignment of 8, but ::operator new() uses alignment of 4,
+// and an assertion will fail below after the call to MakeNewBlock()
+// if you try to use a larger alignment.
+#ifdef __i386__
+ static const int kDefaultAlignment = 4;
+#else
+ static const int kDefaultAlignment = 8;
+#endif
+
+ protected:
+ bool SatisfyAlignment(const size_t alignment);
+ void MakeNewBlock(const uint32 alignment);
+ void* GetMemoryFallback(const size_t size, const int align);
+ void* GetMemory(const size_t size, const int align) {
+ assert(remaining_ <= block_size_); // an invariant
+ if (size > 0 && size < remaining_ && align == 1) { // common case
+ void* result = freestart_;
+ freestart_ += size;
+ remaining_ -= size;
+ return result;
+ }
+ return GetMemoryFallback(size, align);
+ }
+
+ size_t remaining_;
+
+ private:
+ struct AllocatedBlock {
+ char* mem;
+ size_t size;
+ };
+
+  // Allocates a new block of at least block_size, with the specified
+ // alignment.
+ // The returned AllocatedBlock* is valid until the next call to AllocNewBlock
+ // or Reset (i.e. anything that might affect overflow_blocks_).
+ AllocatedBlock* AllocNewBlock(const size_t block_size,
+ const uint32 alignment);
+
+ const size_t block_size_;
+ char* freestart_; // beginning of the free space in most recent block
+ char* freestart_when_empty_; // beginning of the free space when we're empty
+ // STL vector isn't as efficient as it could be, so we use an array at first
+ size_t blocks_alloced_; // how many of the first_blocks_ have been alloced
+ AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary
+ // if the first_blocks_ aren't enough, expand into overflow_blocks_.
+ std::vector<AllocatedBlock>* overflow_blocks_;
+
+ void FreeBlocks(); // Frees all except first block
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Arena);
+};
+
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_ARENA_H_
diff --git a/tensorflow/core/lib/core/arena_test.cc b/tensorflow/core/lib/core/arena_test.cc
new file mode 100644
index 0000000000..fa147c3014
--- /dev/null
+++ b/tensorflow/core/lib/core/arena_test.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/lib/core/arena.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace core {
+namespace {
+
+// Write random data to allocated memory
+static void TestMemory(void* mem, int size) {
+ // Check that we can memset the entire memory
+ memset(mem, 0xaa, size);
+
+ // Do some memory allocation to check that the arena doesn't mess up
+ // the internal memory allocator
+ char* tmp[100];
+ for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) {
+ tmp[i] = new char[i * i + 1];
+ }
+
+ memset(mem, 0xcc, size);
+
+  // Free up the allocated memory.
+ for (size_t i = 0; i < TF_ARRAYSIZE(tmp); i++) {
+ delete[] tmp[i];
+ }
+
+ // Check that we can memset the entire memory
+ memset(mem, 0xee, size);
+}
+
+TEST(ArenaTest, TestBasicArena) {
+ Arena a(1024);
+ char* memory = a.Alloc(100);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 100);
+
+ // Allocate again
+ memory = a.Alloc(100);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 100);
+}
+
+TEST(ArenaTest, TestVariousArenaSizes) {
+ {
+ Arena a(1024);
+
+ // Allocate blocksize
+ char* memory = a.Alloc(1024);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 1024);
+
+ // Allocate another blocksize
+ char* memory2 = a.Alloc(1024);
+ ASSERT_NE(memory2, nullptr);
+ TestMemory(memory2, 1024);
+ }
+
+ // Allocate an arena and allocate two blocks
+ // that together exceed a block size
+ {
+ Arena a(1024);
+
+    // Allocate most of one block
+ char* memory = a.Alloc(768);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 768);
+
+    // Allocate a second chunk; together the two exceed one block size
+ char* memory2 = a.Alloc(768);
+ ASSERT_NE(memory2, nullptr);
+ TestMemory(memory2, 768);
+ }
+
+ // Allocate larger than a blocksize
+ {
+ Arena a(1024);
+
+ char* memory = a.Alloc(10240);
+ ASSERT_NE(memory, nullptr);
+ TestMemory(memory, 10240);
+
+    // Allocate a second chunk, also larger than the block size
+ char* memory2 = a.Alloc(1234);
+ ASSERT_NE(memory2, nullptr);
+ TestMemory(memory2, 1234);
+ }
+}
+
+} // namespace
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/bit_cast_test.cc b/tensorflow/core/lib/core/bit_cast_test.cc
new file mode 100644
index 0000000000..0ea583e96f
--- /dev/null
+++ b/tensorflow/core/lib/core/bit_cast_test.cc
@@ -0,0 +1,95 @@
+// Unit test for bit_cast template.
+
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+// Marshall and unmarshall.
+// ISO spec C++ section 3.9 promises this will work.
+
+template <int N>
+struct marshall {
+ char buf[N];
+};
+
+template <class T>
+void TestMarshall(const T values[], int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ T t0 = values[i];
+ marshall<sizeof(T)> m0 = bit_cast<marshall<sizeof(T)> >(t0);
+ T t1 = bit_cast<T>(m0);
+ marshall<sizeof(T)> m1 = bit_cast<marshall<sizeof(T)> >(t1);
+ ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T)));
+ ASSERT_EQ(0, memcmp(&m0, &m1, sizeof(T)));
+ }
+}
+
+// Convert back and forth to an integral type. The C++ standard does
+// not guarantee this will work.
+//
+// There are implicit assumptions about sizeof(float) and
+// sizeof(double); these assumptions are pervasive in existing code.
+
+template <class T, class I>
+void TestIntegral(const T values[], int num_values) {
+ for (int i = 0; i < num_values; ++i) {
+ T t0 = values[i];
+ I i0 = bit_cast<I>(t0);
+ T t1 = bit_cast<T>(i0);
+ I i1 = bit_cast<I>(t1);
+ ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T)));
+ ASSERT_EQ(i0, i1);
+ }
+}
+
+TEST(BitCast, Bool) {
+ LOG(INFO) << "Test bool";
+ static const bool bool_list[] = {false, true};
+ TestMarshall<bool>(bool_list, TF_ARRAYSIZE(bool_list));
+}
+
+TEST(BitCast, Int32) {
+ static const int32 int_list[] = {0, 1, 100, 2147483647,
+ -1, -100, -2147483647, -2147483647 - 1};
+ TestMarshall<int32>(int_list, TF_ARRAYSIZE(int_list));
+}
+
+TEST(BitCast, Int64) {
+ static const int64 int64_list[] = {0, 1, 1LL << 40, -1, -(1LL << 40)};
+ TestMarshall<int64>(int64_list, TF_ARRAYSIZE(int64_list));
+}
+
+TEST(BitCast, Uint64) {
+ static const uint64 uint64_list[] = {0, 1, 1LLU << 40, 1LLU << 63};
+ TestMarshall<uint64>(uint64_list, TF_ARRAYSIZE(uint64_list));
+}
+
+TEST(BitCast, Float) {
+ static const float float_list[] = {0.0, 1.0, -1.0, 10.0, -10.0, 1e10,
+ 1e20, 1e-10, 1e-20, 2.71828, 3.14159};
+ TestMarshall<float>(float_list, TF_ARRAYSIZE(float_list));
+ TestIntegral<float, int32>(float_list, TF_ARRAYSIZE(float_list));
+ TestIntegral<float, uint32>(float_list, TF_ARRAYSIZE(float_list));
+}
+
+TEST(BitCast, Double) {
+ static const double double_list[] = {
+ 0.0,
+ 1.0,
+ -1.0,
+ 10.0,
+ -10.0,
+ 1e10,
+ 1e100,
+ 1e-10,
+ 1e-100,
+ 2.718281828459045,
+ 3.141592653589793238462643383279502884197169399375105820974944};
+ TestMarshall<double>(double_list, TF_ARRAYSIZE(double_list));
+ TestIntegral<double, int64>(double_list, TF_ARRAYSIZE(double_list));
+ TestIntegral<double, uint64>(double_list, TF_ARRAYSIZE(double_list));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h
new file mode 100644
index 0000000000..5456a63168
--- /dev/null
+++ b/tensorflow/core/lib/core/bits.h
@@ -0,0 +1,84 @@
+#ifndef TENSORFLOW_LIB_CORE_BITS_H_
+#define TENSORFLOW_LIB_CORE_BITS_H_
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+int Log2Floor(uint32 n);
+int Log2Floor64(uint64 n);
+
+// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0.
+int Log2Ceiling(uint32 n);
+int Log2Ceiling64(uint64 n);
+
+// ------------------------------------------------------------------------
+// Implementation details follow
+// ------------------------------------------------------------------------
+
+#if defined(__GNUC__)
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+inline int Log2Floor(uint32 n) {
+ return n == 0 ? -1 : 31 ^ __builtin_clz(n);
+}
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+inline int Log2Floor64(uint64 n) {
+ return n == 0 ? -1 : 63 ^ __builtin_clzll(n);
+}
+
+#else
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+inline int Log2Floor(uint32 n) {
+ if (n == 0)
+ return -1;
+ int log = 0;
+ uint32 value = n;
+ for (int i = 4; i >= 0; --i) {
+ int shift = (1 << i);
+ uint32 x = value >> shift;
+ if (x != 0) {
+ value = x;
+ log += shift;
+ }
+ }
+ assert(value == 1);
+ return log;
+}
+
+// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0.
+// Log2Floor64() is defined in terms of Log2Floor()
+inline int Log2Floor64(uint64 n) {
+ const uint32 topbits = static_cast<uint32>(n >> 32);
+ if (topbits == 0) {
+ // Top bits are zero, so scan in bottom bits
+ return Log2Floor(static_cast<uint32>(n));
+ } else {
+ return 32 + Log2Floor(topbits);
+ }
+}
+
+#endif
+
+inline int Log2Ceiling(uint32 n) {
+ int floor = Log2Floor(n);
+ if (n == (n & ~(n - 1))) // zero or a power of two
+ return floor;
+ else
+ return floor + 1;
+}
+
+inline int Log2Ceiling64(uint64 n) {
+ int floor = Log2Floor64(n);
+ if (n == (n & ~(n - 1))) // zero or a power of two
+ return floor;
+ else
+ return floor + 1;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_BITS_H_
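
A minimal usage sketch for these helpers (not part of the change; it assumes only the declarations above):

    #include <cassert>
    #include "tensorflow/core/lib/core/bits.h"

    void Log2Sketch() {
      // floor(log2(5)) == 2, ceil(log2(5)) == 3.
      assert(tensorflow::Log2Floor(5) == 2);
      assert(tensorflow::Log2Ceiling(5) == 3);
      // For exact powers of two, floor and ceiling agree.
      assert(tensorflow::Log2Floor(8) == 3);
      assert(tensorflow::Log2Ceiling(8) == 3);
      // Zero is special-cased to -1.
      assert(tensorflow::Log2Floor(0) == -1);
    }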
diff --git a/tensorflow/core/lib/core/blocking_counter.h b/tensorflow/core/lib/core/blocking_counter.h
new file mode 100644
index 0000000000..f141be2c76
--- /dev/null
+++ b/tensorflow/core/lib/core/blocking_counter.h
@@ -0,0 +1,41 @@
+#ifndef TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
+#define TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class BlockingCounter {
+ public:
+ BlockingCounter(int initial_count) : count_(initial_count) {
+ CHECK_GE(count_, 0);
+ }
+
+ ~BlockingCounter() {}
+
+ inline void DecrementCount() {
+ mutex_lock l(mu_);
+ --count_;
+ CHECK(count_ >= 0);
+ if (count_ == 0) {
+ cond_var_.notify_all();
+ }
+ }
+
+ inline void Wait() {
+ mutex_lock l(mu_);
+ while (count_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ private:
+ int count_;
+ mutex mu_;
+ condition_variable cond_var_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_BLOCKING_COUNTER_H_
diff --git a/tensorflow/core/lib/core/blocking_counter_test.cc b/tensorflow/core/lib/core/blocking_counter_test.cc
new file mode 100644
index 0000000000..feb0342086
--- /dev/null
+++ b/tensorflow/core/lib/core/blocking_counter_test.cc
@@ -0,0 +1,36 @@
+#include <gtest/gtest.h>
+
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(BlockingCounterTest, TestZero) {
+ BlockingCounter bc(0);
+ bc.Wait();
+}
+
+TEST(BlockingCounterTest, TestSingleThread) {
+ BlockingCounter bc(2);
+ bc.DecrementCount();
+ bc.DecrementCount();
+ bc.Wait();
+}
+
+TEST(BlockingCounterTest, TestMultipleThread) {
+ int N = 3;
+ thread::ThreadPool* thread_pool =
+ new thread::ThreadPool(Env::Default(), "test", N);
+
+ BlockingCounter bc(N);
+ for (int i = 0; i < N; ++i) {
+ thread_pool->Schedule([&bc] { bc.DecrementCount(); });
+ }
+
+ bc.Wait();
+ delete thread_pool;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h
new file mode 100644
index 0000000000..5b72048ac5
--- /dev/null
+++ b/tensorflow/core/lib/core/casts.h
@@ -0,0 +1,85 @@
+// Various Google-specific casting templates.
+//
+// This code is compiled directly on many platforms, including client
+// platforms like Windows, Mac, and embedded systems. Before making
+// any changes here, make sure that you're not breaking any platforms.
+//
+
+#ifndef TENSORFLOW_LIB_CORE_CASTS_H_
+#define TENSORFLOW_LIB_CORE_CASTS_H_
+
+#include <string.h> // for memcpy
+
+namespace tensorflow {
+
+// bit_cast<Dest,Source> is a template function that implements the
+// equivalent of "*reinterpret_cast<Dest*>(&source)". We need this in
+// very low-level functions like the protobuf library and fast math
+// support.
+//
+// float f = 3.14159265358979;
+// int i = bit_cast<int32>(f);
+// // i = 0x40490fdb
+//
+// The classical address-casting method is:
+//
+// // WRONG
+// float f = 3.14159265358979; // WRONG
+// int i = * reinterpret_cast<int*>(&f); // WRONG
+//
+// The address-casting method actually produces undefined behavior
+// according to ISO C++ specification section 3.10, paragraph 15. Roughly, this
+// section says: if an object in memory has one type, and a program
+// accesses it with a different type, then the result is undefined
+// behavior for most values of "different type".
+//
+// This is true for any cast syntax, either *(int*)&f or
+// *reinterpret_cast<int*>(&f). And it is particularly true for
+// conversions between integral lvalues and floating-point lvalues.
+//
+// The purpose of 3.10/15 is to allow optimizing compilers to assume
+// that expressions with different types refer to different memory. gcc
+// 4.0.1 has an optimizer that takes advantage of this. So a
+// non-conforming program quietly produces wildly incorrect output.
+//
+// The problem is not the use of reinterpret_cast. The problem is type
+// punning: holding an object in memory of one type and reading its bits
+// back using a different type.
+//
+// The C++ standard is more subtle and complex than this, but that
+// is the basic idea.
+//
+// Anyways ...
+//
+// bit_cast<> calls memcpy() which is blessed by the standard,
+// especially by the example in section 3.9. Also, of course,
+// bit_cast<> wraps up the nasty logic in one place.
+//
+// Fortunately memcpy() is very fast. In optimized mode, with a
+// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
+// code with the minimal amount of data movement. On a 32-bit system,
+// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
+// compiles to two loads and two stores.
+//
+// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
+//
+// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
+// is likely to surprise you.
+//
+// Props to Bill Gibbons for the compile time assertion technique and
+// Art Komninos and Igor Tandetnik for the msvc experiments.
+//
+// -- mec 2005-10-17
+
+template <class Dest, class Source>
+inline Dest bit_cast(const Source& source) {
+ static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match");
+
+ Dest dest;
+ memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_CASTS_H_
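
As a concrete round trip in the same spirit as the example in the comment above (a sketch only; uint32 is assumed to come from tensorflow/core/platform/port.h):

    #include "tensorflow/core/lib/core/casts.h"
    #include "tensorflow/core/platform/port.h"

    // Inspect the raw IEEE-754 bits of a float and reconstruct the value.
    float RoundTripThroughBits(float f) {
      tensorflow::uint32 bits = tensorflow::bit_cast<tensorflow::uint32>(f);
      // 'bits' holds the exact bit pattern; e.g. bit_cast<uint32>(3.14159265f)
      // is 0x40490fdb, matching the comment above.
      return tensorflow::bit_cast<float>(bits);
    }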
diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/lib/core/coding.cc
new file mode 100644
index 0000000000..efff554742
--- /dev/null
+++ b/tensorflow/core/lib/core/coding.cc
@@ -0,0 +1,164 @@
+#include "tensorflow/core/lib/core/coding.h"
+
+namespace tensorflow {
+namespace core {
+
+void EncodeFixed32(char* buf, uint32 value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ buf[2] = (value >> 16) & 0xff;
+ buf[3] = (value >> 24) & 0xff;
+ }
+}
+
+void EncodeFixed64(char* buf, uint64 value) {
+ if (port::kLittleEndian) {
+ memcpy(buf, &value, sizeof(value));
+ } else {
+ buf[0] = value & 0xff;
+ buf[1] = (value >> 8) & 0xff;
+ buf[2] = (value >> 16) & 0xff;
+ buf[3] = (value >> 24) & 0xff;
+ buf[4] = (value >> 32) & 0xff;
+ buf[5] = (value >> 40) & 0xff;
+ buf[6] = (value >> 48) & 0xff;
+ buf[7] = (value >> 56) & 0xff;
+ }
+}
+
+void PutFixed32(string* dst, uint32 value) {
+ char buf[sizeof(value)];
+ EncodeFixed32(buf, value);
+ dst->append(buf, sizeof(buf));
+}
+
+void PutFixed64(string* dst, uint64 value) {
+ char buf[sizeof(value)];
+ EncodeFixed64(buf, value);
+ dst->append(buf, sizeof(buf));
+}
+
+char* EncodeVarint32(char* dst, uint32 v) {
+ // Operate on characters as unsigneds
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
+ static const int B = 128;
+ if (v < (1 << 7)) {
+ *(ptr++) = v;
+ } else if (v < (1 << 14)) {
+ *(ptr++) = v | B;
+ *(ptr++) = v >> 7;
+ } else if (v < (1 << 21)) {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = v >> 14;
+ } else if (v < (1 << 28)) {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = (v >> 14) | B;
+ *(ptr++) = v >> 21;
+ } else {
+ *(ptr++) = v | B;
+ *(ptr++) = (v >> 7) | B;
+ *(ptr++) = (v >> 14) | B;
+ *(ptr++) = (v >> 21) | B;
+ *(ptr++) = v >> 28;
+ }
+ return reinterpret_cast<char*>(ptr);
+}
+
+void PutVarint32(string* dst, uint32 v) {
+ char buf[5];
+ char* ptr = EncodeVarint32(buf, v);
+ dst->append(buf, ptr - buf);
+}
+
+char* EncodeVarint64(char* dst, uint64 v) {
+ static const int B = 128;
+ unsigned char* ptr = reinterpret_cast<unsigned char*>(dst);
+ while (v >= B) {
+ *(ptr++) = (v & (B - 1)) | B;
+ v >>= 7;
+ }
+ *(ptr++) = static_cast<unsigned char>(v);
+ return reinterpret_cast<char*>(ptr);
+}
+
+void PutVarint64(string* dst, uint64 v) {
+ char buf[10];
+ char* ptr = EncodeVarint64(buf, v);
+ dst->append(buf, ptr - buf);
+}
+
+int VarintLength(uint64_t v) {
+ int len = 1;
+ while (v >= 128) {
+ v >>= 7;
+ len++;
+ }
+ return len;
+}
+
+const char* GetVarint32PtrFallback(const char* p, const char* limit,
+ uint32* value) {
+ uint32 result = 0;
+ for (uint32 shift = 0; shift <= 28 && p < limit; shift += 7) {
+ uint32 byte = *(reinterpret_cast<const unsigned char*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return NULL;
+}
+
+bool GetVarint32(StringPiece* input, uint32* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarint32Ptr(p, limit, value);
+ if (q == NULL) {
+ return false;
+ } else {
+ *input = StringPiece(q, limit - q);
+ return true;
+ }
+}
+
+const char* GetVarint64Ptr(const char* p, const char* limit, uint64* value) {
+ uint64 result = 0;
+ for (uint32 shift = 0; shift <= 63 && p < limit; shift += 7) {
+ uint64 byte = *(reinterpret_cast<const unsigned char*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return NULL;
+}
+
+bool GetVarint64(StringPiece* input, uint64* value) {
+ const char* p = input->data();
+ const char* limit = p + input->size();
+ const char* q = GetVarint64Ptr(p, limit, value);
+ if (q == NULL) {
+ return false;
+ } else {
+ *input = StringPiece(q, limit - q);
+ return true;
+ }
+}
+
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h
new file mode 100644
index 0000000000..0c14bf1bbf
--- /dev/null
+++ b/tensorflow/core/lib/core/coding.h
@@ -0,0 +1,55 @@
+// Endian-neutral encoding:
+// * Fixed-length numbers are encoded with least-significant byte first
+// * In addition we support variable length "varint" encoding
+// * Strings are encoded prefixed by their length in varint format
+
+#ifndef TENSORFLOW_LIB_CORE_CODING_H_
+#define TENSORFLOW_LIB_CORE_CODING_H_
+
+#include "tensorflow/core/lib/core/raw_coding.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace core {
+
+// Lower-level versions of Put... that write directly into a character buffer
+// REQUIRES: dst has enough space for the value being written
+extern void EncodeFixed32(char* dst, uint32 value);
+extern void EncodeFixed64(char* dst, uint64 value);
+extern void PutFixed32(string* dst, uint32 value);
+extern void PutFixed64(string* dst, uint64 value);
+
+extern void PutVarint32(string* dst, uint32 value);
+extern void PutVarint64(string* dst, uint64 value);
+
+extern bool GetVarint32(StringPiece* input, uint32* value);
+extern bool GetVarint64(StringPiece* input, uint64* value);
+
+extern const char* GetVarint32Ptr(const char* p, const char* limit, uint32* v);
+extern const char* GetVarint64Ptr(const char* p, const char* limit, uint64* v);
+
+// Internal routine for use by fallback path of GetVarint32Ptr
+extern const char* GetVarint32PtrFallback(const char* p, const char* limit,
+ uint32* value);
+inline const char* GetVarint32Ptr(const char* p, const char* limit,
+ uint32* value) {
+ if (p < limit) {
+ uint32 result = *(reinterpret_cast<const unsigned char*>(p));
+ if ((result & 128) == 0) {
+ *value = result;
+ return p + 1;
+ }
+ }
+ return GetVarint32PtrFallback(p, limit, value);
+}
+
+extern char* EncodeVarint64(char* dst, uint64 v);
+
+// Returns the length of the varint32 or varint64 encoding of "v"
+extern int VarintLength(uint64_t v);
+
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_CODING_H_
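
A small sketch of the two encodings side by side (not part of the change). The value 300 is 0b1_0010_1100, so its varint form is the two bytes 0xAC 0x02 (7 payload bits per byte, high bit marks continuation), while the fixed form is always four little-endian bytes:

    #include <string>
    #include "tensorflow/core/lib/core/coding.h"

    void CodingSketch() {
      std::string s;
      tensorflow::core::PutVarint32(&s, 300);  // appends 0xAC 0x02
      tensorflow::core::PutFixed32(&s, 300);   // appends 0x2C 0x01 0x00 0x00

      tensorflow::StringPiece input(s);
      tensorflow::uint32 v = 0;
      tensorflow::core::GetVarint32(&input, &v);  // v == 300; advances past the varint
      // The remaining four bytes are the fixed32 encoding.
      tensorflow::uint32 w = tensorflow::core::DecodeFixed32(input.data());  // w == 300
    }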
diff --git a/tensorflow/core/lib/core/coding_test.cc b/tensorflow/core/lib/core/coding_test.cc
new file mode 100644
index 0000000000..5e9e2c5e96
--- /dev/null
+++ b/tensorflow/core/lib/core/coding_test.cc
@@ -0,0 +1,168 @@
+#include "tensorflow/core/lib/core/coding.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace core {
+
+TEST(Coding, Fixed32) {
+ static const int N = 100000;
+
+ string s;
+ for (uint32 v = 0; v < N; v++) {
+ char buf[sizeof(uint32)];
+ EncodeFixed32(buf, v);
+ s.append(buf, sizeof(buf));
+ }
+
+ const char* p = s.data();
+ for (uint32 v = 0; v < N; v++) {
+ uint32 actual = DecodeFixed32(p);
+ ASSERT_EQ(v, actual);
+ p += sizeof(uint32);
+ }
+}
+
+TEST(Coding, Fixed64) {
+ string s;
+ for (int power = 0; power <= 63; power++) {
+ uint64 v = static_cast<uint64>(1) << power;
+ char buf[sizeof(uint64)];
+ EncodeFixed64(buf, v - 1);
+ s.append(buf, sizeof(buf));
+ EncodeFixed64(buf, v + 0);
+ s.append(buf, sizeof(buf));
+ EncodeFixed64(buf, v + 1);
+ s.append(buf, sizeof(buf));
+ }
+
+ const char* p = s.data();
+ for (int power = 0; power <= 63; power++) {
+ uint64 v = static_cast<uint64>(1) << power;
+ uint64 actual;
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v - 1, actual);
+ p += sizeof(uint64);
+
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v + 0, actual);
+ p += sizeof(uint64);
+
+ actual = DecodeFixed64(p);
+ ASSERT_EQ(v + 1, actual);
+ p += sizeof(uint64);
+ }
+}
+
+// Test that encoding routines generate little-endian encodings
+TEST(Coding, EncodingOutput) {
+ char dst[8];
+ EncodeFixed32(dst, 0x04030201);
+ ASSERT_EQ(0x01, static_cast<int>(dst[0]));
+ ASSERT_EQ(0x02, static_cast<int>(dst[1]));
+ ASSERT_EQ(0x03, static_cast<int>(dst[2]));
+ ASSERT_EQ(0x04, static_cast<int>(dst[3]));
+
+ EncodeFixed64(dst, 0x0807060504030201ull);
+ ASSERT_EQ(0x01, static_cast<int>(dst[0]));
+ ASSERT_EQ(0x02, static_cast<int>(dst[1]));
+ ASSERT_EQ(0x03, static_cast<int>(dst[2]));
+ ASSERT_EQ(0x04, static_cast<int>(dst[3]));
+ ASSERT_EQ(0x05, static_cast<int>(dst[4]));
+ ASSERT_EQ(0x06, static_cast<int>(dst[5]));
+ ASSERT_EQ(0x07, static_cast<int>(dst[6]));
+ ASSERT_EQ(0x08, static_cast<int>(dst[7]));
+}
+
+TEST(Coding, Varint32) {
+ string s;
+ for (uint32 i = 0; i < (32 * 32); i++) {
+ uint32 v = (i / 32) << (i % 32);
+ PutVarint32(&s, v);
+ }
+
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ for (uint32 i = 0; i < (32 * 32); i++) {
+ uint32 expected = (i / 32) << (i % 32);
+ uint32 actual;
+ p = GetVarint32Ptr(p, limit, &actual);
+ ASSERT_TRUE(p != NULL);
+ ASSERT_EQ(expected, actual);
+ }
+ ASSERT_EQ(p, s.data() + s.size());
+}
+
+TEST(Coding, Varint64) {
+ // Construct the list of values to check
+ std::vector<uint64> values;
+ // Some special values
+ values.push_back(0);
+ values.push_back(100);
+ values.push_back(~static_cast<uint64>(0));
+ values.push_back(~static_cast<uint64>(0) - 1);
+ for (uint32 k = 0; k < 64; k++) {
+ // Test values near powers of two
+ const uint64 power = 1ull << k;
+ values.push_back(power);
+ values.push_back(power - 1);
+ values.push_back(power + 1);
+ }
+
+ string s;
+ for (size_t i = 0; i < values.size(); i++) {
+ PutVarint64(&s, values[i]);
+ }
+
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ for (size_t i = 0; i < values.size(); i++) {
+ ASSERT_TRUE(p < limit);
+ uint64 actual;
+ p = GetVarint64Ptr(p, limit, &actual);
+ ASSERT_TRUE(p != NULL);
+ ASSERT_EQ(values[i], actual);
+ }
+ ASSERT_EQ(p, limit);
+}
+
+TEST(Coding, Varint32Overflow) {
+ uint32 result;
+ string input("\x81\x82\x83\x84\x85\x11");
+ ASSERT_TRUE(GetVarint32Ptr(input.data(), input.data() + input.size(),
+ &result) == NULL);
+}
+
+TEST(Coding, Varint32Truncation) {
+ uint32 large_value = (1u << 31) + 100;
+ string s;
+ PutVarint32(&s, large_value);
+ uint32 result;
+ for (size_t len = 0; len < s.size() - 1; len++) {
+ ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + len, &result) == NULL);
+ }
+ ASSERT_TRUE(GetVarint32Ptr(s.data(), s.data() + s.size(), &result) != NULL);
+ ASSERT_EQ(large_value, result);
+}
+
+TEST(Coding, Varint64Overflow) {
+ uint64 result;
+ string input("\x81\x82\x83\x84\x85\x81\x82\x83\x84\x85\x11");
+ ASSERT_TRUE(GetVarint64Ptr(input.data(), input.data() + input.size(),
+ &result) == NULL);
+}
+
+TEST(Coding, Varint64Truncation) {
+ uint64 large_value = (1ull << 63) + 100ull;
+ string s;
+ PutVarint64(&s, large_value);
+ uint64 result;
+ for (size_t len = 0; len < s.size() - 1; len++) {
+ ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + len, &result) == NULL);
+ }
+ ASSERT_TRUE(GetVarint64Ptr(s.data(), s.data() + s.size(), &result) != NULL);
+ ASSERT_EQ(large_value, result);
+}
+
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/command_line_flags.cc b/tensorflow/core/lib/core/command_line_flags.cc
new file mode 100644
index 0000000000..0f1072ffaa
--- /dev/null
+++ b/tensorflow/core/lib/core/command_line_flags.cc
@@ -0,0 +1,94 @@
+#include "tensorflow/core/lib/core/command_line_flags.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace tensorflow {
+namespace {
+
+// Templated function to convert a string to the target value type.
+// Returns true if the conversion succeeds; otherwise returns false.
+template <typename T>
+bool StringToValue(const string& content, T* value);
+
+template <>
+bool StringToValue<int32>(const string& content, int32* value) {
+ return str_util::NumericParse32(content, value);
+}
+
+// Parse a single argument by linearly searching through the command table.
+// The input format is: --argument=value.
+// Return OK if the argument is used. It stores the extracted value into the
+// matching flag.
+// Return NOT_FOUND if the argument is not recognized.
+// Return INVALID_ARGUMENT if the command is recognized but its value cannot
+// be extracted.
+template <typename T>
+Status ParseArgument(const string& argument) {
+ for (auto& command :
+ internal::CommandLineFlagRegistry<int>::Instance()->commands) {
+ string prefix = strings::StrCat("--", command.name, "=");
+ if (tensorflow::StringPiece(argument).starts_with(prefix)) {
+ string content = argument.substr(prefix.length());
+ if (StringToValue<T>(content, command.value)) {
+ return Status::OK();
+ }
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Cannot parse integer in: ", argument));
+ }
+ }
+ return Status(error::NOT_FOUND,
+ strings::StrCat("Unknown command: ", argument));
+}
+
+// A specialization for booleans. The input format is:
+// "--argument" or "--noargument".
+// Parse a single argument by linearly searching through the command table.
+// Return OK if the argument is used. The value is stored in the matching flag.
+// Return NOT_FOUND if the argument is not recognized.
+template <>
+Status ParseArgument<bool>(const string& argument) {
+ for (auto& command :
+ internal::CommandLineFlagRegistry<bool>::Instance()->commands) {
+ if (argument == strings::StrCat("--", command.name)) {
+ *command.value = true;
+ return Status::OK();
+ } else if (argument == strings::StrCat("--no", command.name)) {
+ *command.value = false;
+ return Status::OK();
+ }
+ }
+ return Status(error::NOT_FOUND,
+ strings::StrCat("Unknown command: ", argument));
+}
+} // namespace
+
+Status ParseCommandLineFlags(int* argc, char* argv[]) {
+ int unused_argc = 1;
+ for (int index = 1; index < *argc; ++index) {
+ Status s;
+ // Search bool commands.
+ s = ParseArgument<bool>(argv[index]);
+ if (s.ok()) {
+ continue;
+ }
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ }
+ // Search int32 commands.
+ s = ParseArgument<int32>(argv[index]);
+ if (s.ok()) {
+ continue;
+ }
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ }
+ // Move the unused argument to the front by swapping pointers.
+ std::swap(argv[unused_argc++], argv[index]);
+ }
+ *argc = unused_argc;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/command_line_flags.h b/tensorflow/core/lib/core/command_line_flags.h
new file mode 100644
index 0000000000..f1a94c11f9
--- /dev/null
+++ b/tensorflow/core/lib/core/command_line_flags.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
+#define TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace internal {
+
+template <typename T>
+struct CommandLineFlagRegistry {
+ static CommandLineFlagRegistry* Instance() {
+ static CommandLineFlagRegistry instance_;
+ return &instance_;
+ }
+ struct Command {
+ string name;
+ T* value;
+ string text;
+ };
+ std::vector<Command> commands;
+
+ private:
+ CommandLineFlagRegistry() {}
+ TF_DISALLOW_COPY_AND_ASSIGN(CommandLineFlagRegistry);
+};
+
+template <typename T>
+struct CommandLineFlagRegister {
+ CommandLineFlagRegister(const string& name, T* val, const string& text) {
+ CommandLineFlagRegistry<T>::Instance()->commands.push_back(
+ {name, val, text});
+ }
+};
+
+#define TF_DEFINE_variable(type, name, default_value, text) \
+ type FLAGS_##name = default_value; \
+ namespace TF_flags_internal { \
+ tensorflow::internal::CommandLineFlagRegister<type> \
+ TF_flags_internal_var_##name(#name, &FLAGS_##name, text); \
+ } // namespace TF_flags_internal
+
+} // namespace internal
+
+#define TF_DEFINE_int32(name, default_value, text) \
+ TF_DEFINE_variable(int32, name, default_value, text);
+
+#define TF_DEFINE_bool(name, default_value, text) \
+ TF_DEFINE_variable(bool, name, default_value, text);
+
+// Parse argv[1]..argv[*argc-1] into the registered flags, removing used
+// arguments from argv.
+// Returns the number of unused arguments in *argc.
+// Returns an error Status if parsing fails.
+// TODO(opensource): switch to a command line argument parser that can be
+// shared with other tests.
+Status ParseCommandLineFlags(int* argc, char* argv[]);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_COMMAND_LINE_FLAGS_H_
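
A usage sketch for these macros (hypothetical flag names, not part of the change). The flags are defined inside namespace tensorflow so that the unqualified int32 in the macro expansion resolves:

    #include "tensorflow/core/lib/core/command_line_flags.h"
    #include "tensorflow/core/platform/logging.h"

    namespace tensorflow {
    // Hypothetical flags, for illustration only.
    TF_DEFINE_int32(num_workers, 4, "Number of worker threads.");
    TF_DEFINE_bool(verbose, false, "Enable verbose output.");
    }  // namespace tensorflow

    int main(int argc, char* argv[]) {
      // Understands --num_workers=8, --verbose and --noverbose; used arguments
      // are removed and *argc becomes the count of unused ones.
      tensorflow::Status s = tensorflow::ParseCommandLineFlags(&argc, argv);
      if (!s.ok()) {
        LOG(ERROR) << s.ToString();
        return 1;
      }
      LOG(INFO) << "num_workers=" << tensorflow::FLAGS_num_workers
                << " verbose=" << tensorflow::FLAGS_verbose;
      return 0;
    }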
diff --git a/tensorflow/core/lib/core/error_codes.proto b/tensorflow/core/lib/core/error_codes.proto
new file mode 100644
index 0000000000..6735fd8f88
--- /dev/null
+++ b/tensorflow/core/lib/core/error_codes.proto
@@ -0,0 +1,145 @@
+syntax = "proto3";
+
+package tensorflow.error;
+// option cc_enable_arenas = true;
+
+// The canonical error codes for TensorFlow APIs.
+//
+// Warnings:
+//
+// - Do not change any numeric assignments.
+// - Changes to this list should only be made if there is a compelling
+// need that can't be satisfied in another way. Such changes
+// must be approved by at least two OWNERS.
+//
+// Sometimes multiple error codes may apply. Services should return
+// the most specific error code that applies. For example, prefer
+// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply.
+// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION.
+enum Code {
+ // Not an error; returned on success
+ OK = 0;
+
+ // The operation was cancelled (typically by the caller).
+ CANCELLED = 1;
+
+ // Unknown error. An example of where this error may be returned is
+ // if a Status value received from another address space belongs to
+ // an error-space that is not known in this address space. Also
+ // errors raised by APIs that do not return enough error information
+ // may be converted to this error.
+ UNKNOWN = 2;
+
+ // Client specified an invalid argument. Note that this differs
+ // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments
+ // that are problematic regardless of the state of the system
+ // (e.g., a malformed file name).
+ INVALID_ARGUMENT = 3;
+
+ // Deadline expired before operation could complete. For operations
+ // that change the state of the system, this error may be returned
+ // even if the operation has completed successfully. For example, a
+ // successful response from a server could have been delayed long
+ // enough for the deadline to expire.
+ DEADLINE_EXCEEDED = 4;
+
+ // Some requested entity (e.g., file or directory) was not found.
+ // For privacy reasons, this code *may* be returned when the client
+ // does not have the access right to the entity.
+ NOT_FOUND = 5;
+
+ // Some entity that we attempted to create (e.g., file or directory)
+ // already exists.
+ ALREADY_EXISTS = 6;
+
+ // The caller does not have permission to execute the specified
+ // operation. PERMISSION_DENIED must not be used for rejections
+ // caused by exhausting some resource (use RESOURCE_EXHAUSTED
+ // instead for those errors). PERMISSION_DENIED must not be
+ // used if the caller can not be identified (use UNAUTHENTICATED
+ // instead for those errors).
+ PERMISSION_DENIED = 7;
+
+ // The request does not have valid authentication credentials for the
+ // operation.
+ UNAUTHENTICATED = 16;
+
+ // Some resource has been exhausted, perhaps a per-user quota, or
+ // perhaps the entire file system is out of space.
+ RESOURCE_EXHAUSTED = 8;
+
+ // Operation was rejected because the system is not in a state
+ // required for the operation's execution. For example, a directory
+ // to be deleted may be non-empty, an rmdir operation may be applied
+ // to a non-directory, etc.
+ //
+ // A litmus test that may help a service implementor in deciding
+ // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE:
+ // (a) Use UNAVAILABLE if the client can retry just the failing call.
+ // (b) Use ABORTED if the client should retry at a higher-level
+ // (e.g., restarting a read-modify-write sequence).
+ // (c) Use FAILED_PRECONDITION if the client should not retry until
+ // the system state has been explicitly fixed. E.g., if an "rmdir"
+ // fails because the directory is non-empty, FAILED_PRECONDITION
+ // should be returned since the client should not retry unless
+ // they have first fixed up the directory by deleting files from it.
+ // (d) Use FAILED_PRECONDITION if the client performs conditional
+ // REST Get/Update/Delete on a resource and the resource on the
+ // server does not match the condition. E.g., conflicting
+ // read-modify-write on the same resource.
+ FAILED_PRECONDITION = 9;
+
+ // The operation was aborted, typically due to a concurrency issue
+ // like sequencer check failures, transaction aborts, etc.
+ //
+ // See litmus test above for deciding between FAILED_PRECONDITION,
+ // ABORTED, and UNAVAILABLE.
+ ABORTED = 10;
+
+ // Operation was attempted past the valid range. E.g., seeking or
+ // reading past end of file.
+ //
+ // Unlike INVALID_ARGUMENT, this error indicates a problem that may
+ // be fixed if the system state changes. For example, a 32-bit file
+ // system will generate INVALID_ARGUMENT if asked to read at an
+ // offset that is not in the range [0,2^32-1], but it will generate
+ // OUT_OF_RANGE if asked to read from an offset past the current
+ // file size.
+ //
+ // There is a fair bit of overlap between FAILED_PRECONDITION and
+ // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific
+ // error) when it applies so that callers who are iterating through
+ // a space can easily look for an OUT_OF_RANGE error to detect when
+ // they are done.
+ OUT_OF_RANGE = 11;
+
+ // Operation is not implemented or not supported/enabled in this service.
+ UNIMPLEMENTED = 12;
+
+ // Internal errors. This means some invariant expected by the
+ // underlying system has been broken. If you see one of these
+ // errors, something is very broken.
+ INTERNAL = 13;
+
+ // The service is currently unavailable. This is most likely a
+ // transient condition and may be corrected by retrying with
+ // a backoff.
+ //
+ // See litmus test above for deciding between FAILED_PRECONDITION,
+ // ABORTED, and UNAVAILABLE.
+ UNAVAILABLE = 14;
+
+ // Unrecoverable data loss or corruption.
+ DATA_LOSS = 15;
+
+ // An extra enum entry to prevent people from writing code that
+ // fails to compile when a new code is added.
+ //
+ // Nobody should ever reference this enumeration entry. In particular,
+ // if you write C++ code that switches on this enumeration, add a default:
+ // case instead of a case that mentions this enumeration entry.
+ //
+ // Nobody should rely on the value (currently 20) listed here. It
+ // may change in the future.
+ DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20;
+}
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
new file mode 100644
index 0000000000..b0badd8c4d
--- /dev/null
+++ b/tensorflow/core/lib/core/errors.h
@@ -0,0 +1,131 @@
+#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_
+#define TENSORFLOW_LIB_CORE_ERRORS_H_
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace errors {
+
+typedef ::tensorflow::error::Code Code;
+
+// Append some context to an error message. Each time we append
+// context, we put it on a new line, since it is possible for there
+// to be several layers of additional context.
+template <typename... Args>
+void AppendToMessage(::tensorflow::Status* status, Args... args) {
+ *status = ::tensorflow::Status(
+ status->code(),
+ strings::StrCat(status->error_message(), "\n\t", args...));
+}
+
+// For propagating errors when calling a function.
+#define TF_RETURN_IF_ERROR(expr) \
+ do { \
+ const ::tensorflow::Status _status = (expr); \
+ if (TF_PREDICT_FALSE(!_status.ok())) return _status; \
+ } while (0)
+
+#define TF_RETURN_WITH_CONTEXT_IF_ERROR(expr, ...) \
+ do { \
+ ::tensorflow::Status _status = (expr); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ ::tensorflow::errors::AppendToMessage(&_status, __VA_ARGS__); \
+ return _status; \
+ } \
+ } while (0)
+
+// Convenience functions for generating and using error status.
+// Example usage:
+// status.Update(errors::InvalidArgument("The ", foo, " isn't right."));
+// if (errors::IsInvalidArgument(status)) { ... }
+// switch (status.code()) { case error::INVALID_ARGUMENT: ... }
+
+#define DECLARE_ERROR(FUNC, CONST) \
+ template <typename... Args> \
+ inline ::tensorflow::Status FUNC(Args... args) { \
+ return ::tensorflow::Status(::tensorflow::error::CONST, \
+ strings::StrCat(args...)); \
+ } \
+ inline bool Is##FUNC(const ::tensorflow::Status& status) { \
+ return status.code() == ::tensorflow::error::CONST; \
+ }
+
+DECLARE_ERROR(Cancelled, CANCELLED)
+DECLARE_ERROR(InvalidArgument, INVALID_ARGUMENT)
+DECLARE_ERROR(NotFound, NOT_FOUND)
+DECLARE_ERROR(AlreadyExists, ALREADY_EXISTS)
+DECLARE_ERROR(ResourceExhausted, RESOURCE_EXHAUSTED)
+DECLARE_ERROR(Unavailable, UNAVAILABLE)
+DECLARE_ERROR(FailedPrecondition, FAILED_PRECONDITION)
+DECLARE_ERROR(OutOfRange, OUT_OF_RANGE)
+DECLARE_ERROR(Unimplemented, UNIMPLEMENTED)
+DECLARE_ERROR(Internal, INTERNAL)
+DECLARE_ERROR(Aborted, ABORTED)
+DECLARE_ERROR(DeadlineExceeded, DEADLINE_EXCEEDED)
+DECLARE_ERROR(DataLoss, DATA_LOSS)
+DECLARE_ERROR(Unknown, UNKNOWN)
+DECLARE_ERROR(PermissionDenied, PERMISSION_DENIED)
+DECLARE_ERROR(Unauthenticated, UNAUTHENTICATED)
+
+#undef DECLARE_ERROR
+
+// The CanonicalCode() for non-errors.
+using ::tensorflow::error::OK;
+
+// Convenience macros for asserting and handling exceptional conditions.
+// Analogous to the CHECK* macros provided by logging.h.
+//
+// Example use:
+// void Compute(OperationContext* context) {
+// OP_REQUIRES(context, context->num_inputs() == 2,
+// errors::InvalidArgument("FooOp requires 2 arguments"));
+// ...
+// Status status = SomeUncertainMethod();
+// OP_REQUIRES_OK(context, status);
+// ...
+// }
+
+#define OP_REQUIRES(CTX, EXP, STATUS) \
+ if (!(EXP)) { \
+ ::tensorflow::Status _s(STATUS); \
+ VLOG(1) << _s; \
+ (CTX)->SetStatus(_s); \
+ return; \
+ }
+
+#define OP_REQUIRES_OK(CTX, STATUS) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!_s.ok()) { \
+ LOG(WARNING) << _s; \
+ (CTX)->SetStatus(_s); \
+ return; \
+ } \
+ } while (0)
+
+#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \
+ if (!(EXP)) { \
+ ::tensorflow::Status _s(STATUS); \
+ VLOG(1) << _s; \
+ (CTX)->SetStatus(_s); \
+ (CALLBACK)(); \
+ return; \
+ }
+
+#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \
+ do { \
+ ::tensorflow::Status _s(STATUS); \
+ if (!_s.ok()) { \
+ LOG(WARNING) << _s; \
+ (CTX)->SetStatus(_s); \
+ (CALLBACK)(); \
+ return; \
+ } \
+ } while (0)
+
+} // namespace errors
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_ERRORS_H_
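
The usage comment above covers the OP_REQUIRES* macros; for TF_RETURN_IF_ERROR / TF_RETURN_WITH_CONTEXT_IF_ERROR, a brief sketch with hypothetical helper functions:

    #include <string>
    #include "tensorflow/core/lib/core/errors.h"

    // Hypothetical helpers, for illustration only.
    tensorflow::Status OpenInput(const std::string& name);
    tensorflow::Status ReadRecord(std::string* record);

    tensorflow::Status LoadFirstRecord(const std::string& name,
                                       std::string* record) {
      TF_RETURN_IF_ERROR(OpenInput(name));
      // On failure, append which file was being read before propagating.
      TF_RETURN_WITH_CONTEXT_IF_ERROR(ReadRecord(record), "while reading ", name);
      return tensorflow::Status::OK();
    }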
diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h
new file mode 100644
index 0000000000..071e24285a
--- /dev/null
+++ b/tensorflow/core/lib/core/notification.h
@@ -0,0 +1,42 @@
+#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_
+#define TENSORFLOW_UTIL_NOTIFICATION_H_
+
+#include <assert.h>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Notification {
+ public:
+ Notification() : notified_(false) {}
+ ~Notification() {}
+
+ void Notify() {
+ mutex_lock l(mu_);
+ assert(!notified_);
+ notified_ = true;
+ cv_.notify_all();
+ }
+
+ bool HasBeenNotified() {
+ mutex_lock l(mu_);
+ return notified_;
+ }
+
+ void WaitForNotification() {
+ mutex_lock l(mu_);
+ while (!notified_) {
+ cv_.wait(l);
+ }
+ }
+
+ private:
+ mutex mu_;
+ condition_variable cv_;
+ bool notified_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_NOTIFICATION_H_
diff --git a/tensorflow/core/lib/core/notification_test.cc b/tensorflow/core/lib/core/notification_test.cc
new file mode 100644
index 0000000000..a9e8942f05
--- /dev/null
+++ b/tensorflow/core/lib/core/notification_test.cc
@@ -0,0 +1,64 @@
+#include <gtest/gtest.h>
+
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(NotificationTest, TestSingleNotification) {
+ thread::ThreadPool* thread_pool =
+ new thread::ThreadPool(Env::Default(), "test", 1);
+
+ int counter = 0;
+ Notification start;
+ Notification proceed;
+ thread_pool->Schedule([&start, &proceed, &counter] {
+ start.Notify();
+ proceed.WaitForNotification();
+ ++counter;
+ });
+
+ // Wait for the thread to start
+ start.WaitForNotification();
+
+ // The thread should be waiting for the 'proceed' notification.
+ EXPECT_EQ(0, counter);
+
+ // Unblock the thread
+ proceed.Notify();
+
+ delete thread_pool; // Wait for closure to finish.
+
+ // Verify the counter has been incremented
+ EXPECT_EQ(1, counter);
+}
+
+TEST(NotificationTest, TestMultipleThreadsWaitingOnNotification) {
+ const int num_closures = 4;
+ thread::ThreadPool* thread_pool =
+ new thread::ThreadPool(Env::Default(), "test", num_closures);
+
+ mutex lock;
+ int counter = 0;
+ Notification n;
+
+ for (int i = 0; i < num_closures; ++i) {
+ thread_pool->Schedule([&n, &lock, &counter] {
+ n.WaitForNotification();
+ mutex_lock l(lock);
+ ++counter;
+ });
+ }
+ sleep(1);
+
+ EXPECT_EQ(0, counter);
+
+ n.Notify();
+ delete thread_pool; // Wait for all closures to finish.
+ EXPECT_EQ(4, counter);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h
new file mode 100644
index 0000000000..1fe49b75bb
--- /dev/null
+++ b/tensorflow/core/lib/core/raw_coding.h
@@ -0,0 +1,43 @@
+#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_
+#define TENSORFLOW_LIB_CORE_RAW_CODING_H_
+
+#include <string.h>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace core {
+
+// Lower-level versions of Get... that read directly from a character buffer
+// without any bounds checking.
+
+inline uint32 DecodeFixed32(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint32 result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
+ (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
+ (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
+ (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
+ }
+}
+
+inline uint64 DecodeFixed64(const char* ptr) {
+ if (port::kLittleEndian) {
+ // Load the raw bytes
+ uint64 result;
+ memcpy(&result, ptr, sizeof(result)); // gcc optimizes this to a plain load
+ return result;
+ } else {
+ uint64 lo = DecodeFixed32(ptr);
+ uint64 hi = DecodeFixed32(ptr + 4);
+ return (hi << 32) | lo;
+ }
+}
+
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_
diff --git a/tensorflow/core/lib/core/refcount.cc b/tensorflow/core/lib/core/refcount.cc
new file mode 100644
index 0000000000..3ed8c58eb8
--- /dev/null
+++ b/tensorflow/core/lib/core/refcount.cc
@@ -0,0 +1,35 @@
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace core {
+
+RefCounted::RefCounted() : ref_(1) {}
+
+RefCounted::~RefCounted() { DCHECK_EQ(ref_.load(), 0); }
+
+void RefCounted::Ref() const {
+ DCHECK_GE(ref_.load(), 1);
+ ref_.fetch_add(1, std::memory_order_relaxed);
+}
+
+bool RefCounted::Unref() const {
+ DCHECK_GT(ref_.load(), 0);
+ // If ref_==1, this object is owned only by the caller. Bypass a locked op
+ // in that case.
+ if (ref_.load(std::memory_order_acquire) == 1 || ref_.fetch_sub(1) == 1) {
+ // Make DCHECK in ~RefCounted happy
+ DCHECK((ref_.store(0), true));
+ delete this;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool RefCounted::RefCountIsOne() const {
+ return (ref_.load(std::memory_order_acquire) == 1);
+}
+
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/refcount.h b/tensorflow/core/lib/core/refcount.h
new file mode 100644
index 0000000000..f727750f9e
--- /dev/null
+++ b/tensorflow/core/lib/core/refcount.h
@@ -0,0 +1,63 @@
+#ifndef TENSORFLOW_LIB_CORE_REFCOUNT_H_
+#define TENSORFLOW_LIB_CORE_REFCOUNT_H_
+
+#include <atomic>
+
+namespace tensorflow {
+namespace core {
+
+class RefCounted {
+ public:
+ // Initial reference count is one.
+ RefCounted();
+
+ // Increments reference count by one.
+ void Ref() const;
+
+ // Decrements reference count by one. If the count remains
+ // positive, returns false. When the count reaches zero, returns
+ // true and deletes this, in which case the caller must not access
+ // the object afterward.
+ bool Unref() const;
+
+ // Return whether the reference count is one.
+ // If the reference count is used in the conventional way, a
+ // reference count of 1 implies that the current thread owns the
+ // reference and no other thread shares it.
+ // This call performs the test for a reference count of one, and
+ // performs the memory barrier needed for the owning thread
+ // to act on the object, knowing that it has exclusive access to the
+ // object.
+ bool RefCountIsOne() const;
+
+ protected:
+ // Make destructor protected so that RefCounted objects cannot
+ // be instantiated directly. Only subclasses can be instantiated.
+ virtual ~RefCounted();
+
+ private:
+ mutable std::atomic_int_fast32_t ref_;
+
+ RefCounted(const RefCounted&) = delete;
+ void operator=(const RefCounted&) = delete;
+};
+
+// Helper class to unref an object when out-of-scope.
+class ScopedUnref {
+ public:
+ explicit ScopedUnref(RefCounted* o) : obj_(o) {}
+ ~ScopedUnref() {
+ if (obj_) obj_->Unref();
+ }
+
+ private:
+ RefCounted* obj_;
+
+ ScopedUnref(const ScopedUnref&) = delete;
+ void operator=(const ScopedUnref&) = delete;
+};
+
+} // namespace core
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_REFCOUNT_H_
diff --git a/tensorflow/core/lib/core/refcount_test.cc b/tensorflow/core/lib/core/refcount_test.cc
new file mode 100644
index 0000000000..c042be2d61
--- /dev/null
+++ b/tensorflow/core/lib/core/refcount_test.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/lib/core/refcount.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace core {
+namespace {
+
+static int constructed = 0;
+static int destroyed = 0;
+
+class MyRef : public RefCounted {
+ public:
+ MyRef() { constructed++; }
+ ~MyRef() override { destroyed++; }
+};
+
+class RefTest : public testing::Test {
+ public:
+ RefTest() {
+ constructed = 0;
+ destroyed = 0;
+ }
+};
+
+TEST_F(RefTest, New) {
+ MyRef* ref = new MyRef;
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(0, destroyed);
+ ref->Unref();
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(1, destroyed);
+}
+
+TEST_F(RefTest, RefUnref) {
+ MyRef* ref = new MyRef;
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(0, destroyed);
+ ref->Ref();
+ ASSERT_EQ(0, destroyed);
+ ref->Unref();
+ ASSERT_EQ(0, destroyed);
+ ref->Unref();
+ ASSERT_EQ(1, destroyed);
+}
+
+TEST_F(RefTest, RefCountOne) {
+ MyRef* ref = new MyRef;
+ ASSERT_TRUE(ref->RefCountIsOne());
+ ref->Unref();
+}
+
+TEST_F(RefTest, RefCountNotOne) {
+ MyRef* ref = new MyRef;
+ ref->Ref();
+ ASSERT_FALSE(ref->RefCountIsOne());
+ ref->Unref();
+ ref->Unref();
+}
+
+TEST_F(RefTest, ConstRefUnref) {
+ const MyRef* cref = new MyRef;
+ ASSERT_EQ(1, constructed);
+ ASSERT_EQ(0, destroyed);
+ cref->Ref();
+ ASSERT_EQ(0, destroyed);
+ cref->Unref();
+ ASSERT_EQ(0, destroyed);
+ cref->Unref();
+ ASSERT_EQ(1, destroyed);
+}
+
+TEST_F(RefTest, ReturnOfUnref) {
+ MyRef* ref = new MyRef;
+ ref->Ref();
+ EXPECT_FALSE(ref->Unref());
+ EXPECT_TRUE(ref->Unref());
+}
+
+TEST_F(RefTest, ScopedUnref) {
+ { ScopedUnref unref(new MyRef); }
+ EXPECT_EQ(destroyed, 1);
+}
+
+TEST_F(RefTest, ScopedUnref_Nullptr) {
+ { ScopedUnref unref(nullptr); }
+ EXPECT_EQ(destroyed, 0);
+}
+
+} // namespace
+} // namespace core
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc
new file mode 100644
index 0000000000..24ce842560
--- /dev/null
+++ b/tensorflow/core/lib/core/status.cc
@@ -0,0 +1,107 @@
+#include "tensorflow/core/public/status.h"
+#include <stdio.h>
+
+namespace tensorflow {
+
+Status::Status(tensorflow::error::Code code, StringPiece msg) {
+ assert(code != tensorflow::error::OK);
+ state_ = new State;
+ state_->code = code;
+ state_->msg = msg.ToString();
+}
+Status::~Status() { delete state_; }
+
+void Status::Update(const Status& new_status) {
+ if (ok()) {
+ *this = new_status;
+ }
+}
+
+void Status::SlowCopyFrom(const State* src) {
+ delete state_;
+ if (src == nullptr) {
+ state_ = nullptr;
+ } else {
+ state_ = new State(*src);
+ }
+}
+
+const string& Status::empty_string() {
+ static string* empty = new string;
+ return *empty;
+}
+
+string Status::ToString() const {
+ if (state_ == NULL) {
+ return "OK";
+ } else {
+ char tmp[30];
+ const char* type;
+ switch (code()) {
+ case tensorflow::error::CANCELLED:
+ type = "Cancelled";
+ break;
+ case tensorflow::error::UNKNOWN:
+ type = "Unknown";
+ break;
+ case tensorflow::error::INVALID_ARGUMENT:
+ type = "Invalid argument";
+ break;
+ case tensorflow::error::DEADLINE_EXCEEDED:
+ type = "Deadline exceeded";
+ break;
+ case tensorflow::error::NOT_FOUND:
+ type = "Not found";
+ break;
+ case tensorflow::error::ALREADY_EXISTS:
+ type = "Already exists";
+ break;
+ case tensorflow::error::PERMISSION_DENIED:
+ type = "Permission denied";
+ break;
+ case tensorflow::error::UNAUTHENTICATED:
+ type = "Unauthenticated";
+ break;
+ case tensorflow::error::RESOURCE_EXHAUSTED:
+ type = "Resource exhausted";
+ break;
+ case tensorflow::error::FAILED_PRECONDITION:
+ type = "Failed precondition";
+ break;
+ case tensorflow::error::ABORTED:
+ type = "Aborted";
+ break;
+ case tensorflow::error::OUT_OF_RANGE:
+ type = "Out of range";
+ break;
+ case tensorflow::error::UNIMPLEMENTED:
+ type = "Unimplemented";
+ break;
+ case tensorflow::error::INTERNAL:
+ type = "Internal";
+ break;
+ case tensorflow::error::UNAVAILABLE:
+ type = "Unavailable";
+ break;
+ case tensorflow::error::DATA_LOSS:
+ type = "Data loss";
+ break;
+ default:
+ snprintf(tmp, sizeof(tmp), "Unknown code(%d)",
+ static_cast<int>(code()));
+ type = tmp;
+ break;
+ }
+ string result(type);
+ result += ": ";
+ result += state_->msg;
+ return result;
+ }
+}
+
+std::ostream& operator<<(std::ostream& os, const Status& x) {
+ os << x.ToString();
+ return os;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc
new file mode 100644
index 0000000000..3ef6b3302a
--- /dev/null
+++ b/tensorflow/core/lib/core/status_test.cc
@@ -0,0 +1,84 @@
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Status, OK) {
+ EXPECT_EQ(Status::OK().code(), error::OK);
+ EXPECT_EQ(Status::OK().error_message(), "");
+ EXPECT_OK(Status::OK());
+ ASSERT_OK(Status::OK());
+ EXPECT_EQ(Status::OK(), Status());
+ Status s;
+ EXPECT_TRUE(s.ok());
+}
+
+TEST(DeathStatus, CheckOK) {
+ Status status(errors::InvalidArgument("Invalid"));
+ ASSERT_DEATH(TF_CHECK_OK(status), "Invalid");
+}
+
+TEST(Status, Set) {
+ Status status;
+ status = Status(error::CANCELLED, "Error message");
+ EXPECT_EQ(status.code(), error::CANCELLED);
+ EXPECT_EQ(status.error_message(), "Error message");
+}
+
+TEST(Status, Copy) {
+ Status a(errors::InvalidArgument("Invalid"));
+ Status b(a);
+ ASSERT_EQ(a.ToString(), b.ToString());
+}
+
+TEST(Status, Assign) {
+ Status a(errors::InvalidArgument("Invalid"));
+ Status b;
+ b = a;
+ ASSERT_EQ(a.ToString(), b.ToString());
+}
+
+TEST(Status, Update) {
+ Status s;
+ s.Update(Status::OK());
+ ASSERT_TRUE(s.ok());
+ Status a(errors::InvalidArgument("Invalid"));
+ s.Update(a);
+ ASSERT_EQ(s.ToString(), a.ToString());
+ Status b(errors::Internal("Internal"));
+ s.Update(b);
+ ASSERT_EQ(s.ToString(), a.ToString());
+ s.Update(Status::OK());
+ ASSERT_EQ(s.ToString(), a.ToString());
+ ASSERT_FALSE(s.ok());
+}
+
+TEST(Status, EqualsOK) { ASSERT_EQ(Status::OK(), Status()); }
+
+TEST(Status, EqualsSame) {
+ Status a(errors::InvalidArgument("Invalid"));
+ Status b(errors::InvalidArgument("Invalid"));
+ ASSERT_EQ(a, b);
+}
+
+TEST(Status, EqualsCopy) {
+ const Status a(errors::InvalidArgument("Invalid"));
+ const Status b = a;
+ ASSERT_EQ(a, b);
+}
+
+TEST(Status, EqualsDifferentCode) {
+ const Status a(errors::InvalidArgument("message"));
+ const Status b(errors::Internal("message"));
+ ASSERT_NE(a, b);
+}
+
+TEST(Status, EqualsDifferentMessage) {
+ const Status a(errors::InvalidArgument("message"));
+ const Status b(errors::InvalidArgument("another"));
+ ASSERT_NE(a, b);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
new file mode 100644
index 0000000000..b3b4db429f
--- /dev/null
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -0,0 +1,20 @@
+#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+
+// Macros for testing the results of functions that return tensorflow::Status.
+
+#define EXPECT_OK(statement) EXPECT_EQ(::tensorflow::Status::OK(), (statement))
+#define ASSERT_OK(statement) ASSERT_EQ(::tensorflow::Status::OK(), (statement))
+
+// There are no EXPECT_NOT_OK/ASSERT_NOT_OK macros since they would not
+// provide much value (when they fail, they would just print the OK status,
+// which conveys no more information than EXPECT_FALSE(status.ok())).
+// If you want to check for particular errors, better alternatives are:
+// EXPECT_EQ(::util::Status(...expected error...), status.StripMessage());
+// EXPECT_THAT(status.ToString(), HasSubstr("expected error"));
+// Also, see testing/lib/util/status_util.h.
+
+#endif // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
new file mode 100644
index 0000000000..57c5139f47
--- /dev/null
+++ b/tensorflow/core/lib/core/stringpiece.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+#include <iostream>
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+size_t StringPiece::Hasher::operator()(StringPiece s) const {
+ return Hash64(s.data(), s.size());
+}
+
+std::ostream& operator<<(std::ostream& o, StringPiece piece) {
+ o.write(piece.data(), piece.size());
+ return o;
+}
+
+bool StringPiece::contains(StringPiece s) const {
+ return memmem(data_, size_, s.data_, s.size_) != nullptr;
+}
+
+size_t StringPiece::find(char c, size_t pos) const {
+ if (pos >= size_) {
+ return npos;
+ }
+ const char* result =
+ reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
+ return result != NULL ? result - data_ : npos;
+}
+
+// Search range is [0..pos] inclusive. If pos == npos, search everything.
+size_t StringPiece::rfind(char c, size_t pos) const {
+ if (size_ == 0) return npos;
+ for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
+ if (*p == c) {
+ return p - data_;
+ }
+ }
+ return npos;
+}
+
+bool StringPiece::Consume(StringPiece x) {
+ if (starts_with(x)) {
+ remove_prefix(x.size_);
+ return true;
+ }
+ return false;
+}
+
+StringPiece StringPiece::substr(size_t pos, size_t n) const {
+ if (pos > size_) pos = size_;
+ if (n > size_ - pos) n = size_ - pos;
+ return StringPiece(data_ + pos, n);
+}
+
+const StringPiece::size_type StringPiece::npos = size_type(-1);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
new file mode 100644
index 0000000000..17d4b294e9
--- /dev/null
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -0,0 +1,159 @@
+// StringPiece is a simple structure containing a pointer into some external
+// storage and a size. The user of a StringPiece must ensure that the slice
+// is not used after the corresponding external storage has been
+// deallocated.
+//
+// Multiple threads can invoke const methods on a StringPiece without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same StringPiece must use
+// external synchronization.
+
+#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+
+#include <assert.h>
+#include <stddef.h>
+#include <string.h>
+#include <iosfwd>
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class StringPiece {
+ public:
+ typedef size_t size_type;
+
+ // Create an empty slice.
+ StringPiece() : data_(""), size_(0) {}
+
+ // Create a slice that refers to d[0,n-1].
+ StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
+
+ // Create a slice that refers to the contents of "s"
+ StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
+
+ // Create a slice that refers to s[0,strlen(s)-1]
+ StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
+
+ void set(const void* data, size_t len) {
+ data_ = reinterpret_cast<const char*>(data);
+ size_ = len;
+ }
+
+ // Return a pointer to the beginning of the referenced data
+ const char* data() const { return data_; }
+
+ // Return the length (in bytes) of the referenced data
+ size_t size() const { return size_; }
+
+ // Return true iff the length of the referenced data is zero
+ bool empty() const { return size_ == 0; }
+
+ typedef const char* const_iterator;
+ typedef const char* iterator;
+ iterator begin() const { return data_; }
+ iterator end() const { return data_ + size_; }
+
+ static const size_t npos;
+
+ // Return the ith byte in the referenced data.
+ // REQUIRES: n < size()
+ char operator[](size_t n) const {
+ assert(n < size());
+ return data_[n];
+ }
+
+ // Change this slice to refer to an empty array
+ void clear() {
+ data_ = "";
+ size_ = 0;
+ }
+
+ // Drop the first "n" bytes from this slice.
+ void remove_prefix(size_t n) {
+ assert(n <= size());
+ data_ += n;
+ size_ -= n;
+ }
+
+ void remove_suffix(size_t n) {
+ assert(size_ >= n);
+ size_ -= n;
+ }
+
+ size_t find(char c, size_t pos = 0) const;
+ size_t rfind(char c, size_t pos = npos) const;
+ bool contains(StringPiece s) const;
+
+ // Checks whether StringPiece starts with x and if so advances the beginning
+ // of it to past the match. It's basically a shortcut for starts_with
+ // followed by remove_prefix.
+ bool Consume(StringPiece x);
+
+ StringPiece substr(size_t pos, size_t n = npos) const;
+
+ struct Hasher {
+ size_t operator()(StringPiece arg) const;
+ };
+
+ // Return a string that contains the copy of the referenced data.
+ std::string ToString() const { return std::string(data_, size_); }
+
+ // Three-way comparison. Returns value:
+ // < 0 iff "*this" < "b",
+ // == 0 iff "*this" == "b",
+ // > 0 iff "*this" > "b"
+ int compare(StringPiece b) const;
+
+ // Return true iff "x" is a prefix of "*this"
+ bool starts_with(StringPiece x) const {
+ return ((size_ >= x.size_) && (memcmp(data_, x.data_, x.size_) == 0));
+ }
+ // Return true iff "x" is a suffix of "*this"
+ bool ends_with(StringPiece x) const {
+ return ((size_ >= x.size_) &&
+ (memcmp(data_ + (size_ - x.size_), x.data_, x.size_) == 0));
+ }
+
+ private:
+ const char* data_;
+ size_t size_;
+
+ // Intentionally copyable
+};
+
+inline bool operator==(StringPiece x, StringPiece y) {
+ return ((x.size() == y.size()) &&
+ (memcmp(x.data(), y.data(), x.size()) == 0));
+}
+
+inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
+
+inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
+inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
+inline bool operator<=(StringPiece x, StringPiece y) {
+ return x.compare(y) <= 0;
+}
+inline bool operator>=(StringPiece x, StringPiece y) {
+ return x.compare(y) >= 0;
+}
+
+inline int StringPiece::compare(StringPiece b) const {
+ const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
+ int r = memcmp(data_, b.data_, min_len);
+ if (r == 0) {
+ if (size_ < b.size_)
+ r = -1;
+ else if (size_ > b.size_)
+ r = +1;
+ }
+ return r;
+}
+
+// allow StringPiece to be logged
+extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_STRINGPIECE_H_
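
A short sketch of typical StringPiece use (not part of the change); it only touches members declared above and, like any StringPiece, borrows the caller's storage:

    #include "tensorflow/core/lib/core/stringpiece.h"

    // Split "key=value" into two pieces without copying the underlying bytes.
    // Both outputs stay valid only as long as 'line' does.
    bool SplitKeyValue(tensorflow::StringPiece line, tensorflow::StringPiece* key,
                       tensorflow::StringPiece* value) {
      size_t pos = line.find('=');
      if (pos == tensorflow::StringPiece::npos) return false;
      *key = line.substr(0, pos);
      *value = line.substr(pos + 1);
      return true;
    }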
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
new file mode 100644
index 0000000000..e9b84d3102
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -0,0 +1,108 @@
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+namespace thread {
+
+struct ThreadPool::Waiter {
+ condition_variable cv;
+ bool ready;
+};
+
+ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
+ : ThreadPool(env, ThreadOptions(), name, num_threads) {}
+
+ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
+ const string& name, int num_threads)
+ : name_(name) {
+ CHECK_GE(num_threads, 1);
+ string name_prefix = "tf_" + name_;
+ for (int i = 0; i < num_threads; i++) {
+ threads_.push_back(env->StartThread(thread_options, name_prefix,
+ [this]() { WorkerLoop(); }));
+ }
+}
+
+ThreadPool::~ThreadPool() {
+ {
+ // Wait for all work to get done.
+ mutex_lock l(mu_);
+
+ // Inform every thread to exit.
+ for (size_t i = 0; i < threads_.size(); ++i) {
+ pending_.push_back({nullptr, 0});
+ }
+
+ // Wakeup all waiters.
+ for (auto w : waiters_) {
+ w->ready = true;
+ w->cv.notify_one();
+ }
+ }
+
+ // Wait for threads to finish.
+ for (auto t : threads_) {
+ delete t;
+ }
+}
+
+bool ThreadPool::HasPendingClosures() const {
+ mutex_lock l(mu_);
+ return pending_.size() != 0;
+}
+
+void ThreadPool::Schedule(std::function<void()> fn) {
+ CHECK(fn != nullptr);
+ uint64 id = 0;
+ if (port::Tracing::IsActive()) {
+ id = port::Tracing::UniqueId();
+ port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
+ id);
+ }
+
+ mutex_lock l(mu_);
+ pending_.push_back({fn, id});
+ if (!waiters_.empty()) {
+ Waiter* w = waiters_.back();
+ waiters_.pop_back();
+ w->ready = true;
+ w->cv.notify_one();
+ }
+}
+
+void ThreadPool::WorkerLoop() {
+ port::Tracing::RegisterCurrentThread(name_.c_str());
+ mutex_lock l(mu_);
+ Waiter w;
+ while (true) {
+ while (pending_.empty()) {
+ // Wait for work to be assigned to me
+ w.ready = false;
+ waiters_.push_back(&w);
+ while (!w.ready) {
+ w.cv.wait(l);
+ }
+ }
+ // Pick up pending work
+ Item item = pending_.front();
+ pending_.pop_front();
+ if (item.fn == nullptr) {
+ break;
+ }
+ mu_.unlock();
+ if (item.id != 0) {
+ port::Tracing::ScopedActivity region(
+ port::Tracing::EventCategory::kRunClosure, item.id);
+ item.fn();
+ } else {
+ item.fn();
+ }
+ mu_.lock();
+ }
+}
+
+} // namespace thread
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
new file mode 100644
index 0000000000..5cf780fa86
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -0,0 +1,59 @@
+#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
+
+#include <deque>
+#include <functional>
+#include <thread>
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace thread {
+
+class ThreadPool {
+ public:
+ // Construct a pool that contains "num_threads" threads with specified "name".
+ // env->StartThread() is used to create individual threads.
+ //
+ // REQUIRES: num_threads > 0
+ ThreadPool(Env* env, const string& name, int num_threads);
+
+ // Like the constructor above, but each thread is created via
+ // env->StartThread() with the given "thread_options".
+ //
+ // REQUIRES: num_threads > 0
+ ThreadPool(Env* env, const ThreadOptions& thread_options, const string& name,
+ int num_threads);
+
+ // Wait until all scheduled work has finished and then destroy the
+ // set of threads.
+ virtual ~ThreadPool();
+
+ // Schedule fn() for execution in the pool of threads.
+ virtual void Schedule(std::function<void()> fn);
+
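+ // Returns true iff closures have been scheduled but not yet picked up by a
+ // worker thread.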
+ virtual bool HasPendingClosures() const;
+
+ private:
+ struct Waiter;
+ struct Item {
+ std::function<void()> fn;
+ uint64 id;
+ };
+
+ void WorkerLoop();
+
+ const string name_;
+ mutable mutex mu_;
+ std::vector<Thread*> threads_; // All threads
+ std::vector<Waiter*> waiters_; // Stack of waiting threads.
+ std::deque<Item> pending_; // Queue of pending work
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
+};
+
+} // namespace thread
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_CORE_THREADPOOL_H_
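
A minimal usage sketch of the API declared above (illustrative only): closures are queued with Schedule(), and the destructor blocks until every queued closure has run, so the counter below is fully updated by the time the pool goes out of scope.

    #include <atomic>
    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/public/env.h"

    int ScheduleExample() {
      std::atomic<int> counter(0);
      {
        tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                            "example", 4 /* num_threads */);
        for (int i = 0; i < 100; ++i) {
          pool.Schedule([&counter]() { counter.fetch_add(1); });
        }
      }  // ~ThreadPool() drains the queue and joins the workers.
      return counter.load();  // Always 100.
    }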
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
new file mode 100644
index 0000000000..f4909c445c
--- /dev/null
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -0,0 +1,93 @@
+#include "tensorflow/core/lib/core/threadpool.h"
+
+#include <atomic>
+
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace thread {
+
+static const int kNumThreads = 30;
+
+TEST(ThreadPool, Empty) {
+ for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+ fprintf(stderr, "Testing with %d threads\n", num_threads);
+ ThreadPool pool(Env::Default(), "test", num_threads);
+ }
+}
+
+TEST(ThreadPool, DoWork) {
+ for (int num_threads = 1; num_threads < kNumThreads; num_threads++) {
+ fprintf(stderr, "Testing with %d threads\n", num_threads);
+ const int kWorkItems = 15;
+ bool work[kWorkItems];
+ for (int i = 0; i < kWorkItems; i++) {
+ work[i] = false;
+ }
+ {
+ ThreadPool pool(Env::Default(), "test", num_threads);
+ for (int i = 0; i < kWorkItems; i++) {
+ pool.Schedule([&work, i]() {
+ ASSERT_FALSE(work[i]);
+ work[i] = true;
+ });
+ }
+ }
+ for (int i = 0; i < kWorkItems; i++) {
+ ASSERT_TRUE(work[i]);
+ }
+ }
+}
+
+static void BM_Sequential(int iters) {
+ ThreadPool pool(Env::Default(), "test", kNumThreads);
+ // Decrement count sequentially until 0.
+ int count = iters;
+ mutex done_lock;
+ condition_variable done;
+ bool done_flag = false;
+ std::function<void()> work = [&pool, &count, &done_lock, &done, &done_flag,
+ &work]() {
+ if (count--) {
+ pool.Schedule(work);
+ } else {
+ mutex_lock l(done_lock);
+ done_flag = true;
+ done.notify_all();
+ }
+ };
+ work();
+ mutex_lock l(done_lock);
+ if (!done_flag) {
+ done.wait(l);
+ }
+}
+BENCHMARK(BM_Sequential);
+
+static void BM_Parallel(int iters) {
+ ThreadPool pool(Env::Default(), "test", kNumThreads);
+ // Decrement count concurrently until 0.
+ std::atomic_int_fast32_t count(iters);
+ mutex done_lock;
+ condition_variable done;
+ bool done_flag = false;
+ for (int i = 0; i < iters; ++i) {
+ pool.Schedule([&count, &done_lock, &done, &done_flag]() {
+ if (count.fetch_sub(1) == 1) {
+ mutex_lock l(done_lock);
+ done_flag = true;
+ done.notify_all();
+ }
+ });
+ }
+ mutex_lock l(done_lock);
+ if (!done_flag) {
+ done.wait(l);
+ }
+}
+BENCHMARK(BM_Parallel);
+
+} // namespace thread
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h
new file mode 100644
index 0000000000..813fb126e3
--- /dev/null
+++ b/tensorflow/core/lib/gtl/array_slice.h
@@ -0,0 +1,299 @@
+// An ArraySlice<T> represents an immutable array of elements of type
+// T. It has a length "length", and a base pointer "ptr", and the
+// array it represents contains the elements "ptr[0] .. ptr[len-1]".
+// The backing store for the array is *not* owned by the ArraySlice
+// object, and clients must arrange for the backing store to remain
+// live while the ArraySlice object is in use.
+//
+// An ArraySlice<T> is somewhat analogous to a StringPiece, but for
+// array elements of type T.
+//
+// Implicit conversion operations are provided from types such as
+// std::vector<T> and util::gtl::InlinedVector<T, N>. Note that ArraySlice
+// objects constructed from types in this way may be invalidated by
+// any operations that mutate the underlying vector.
+//
+// One common use for ArraySlice is when passing arguments to a
+// routine where you want to be able to accept a variety of array
+// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array,
+// etc.). The usual approach here is to have the client explicitly
+// pass in a pointer and a length, as in:
+//
+// void MyRoutine(const int* elems, int N) {
+// for (int i = 0; i < N; i++) { .. do something with elems[i] .. }
+// }
+//
+// Unfortunately, this leads to ugly and error-prone code at the call site:
+//
+// std::vector<int> my_vector;
+// MyRoutine(vector_as_array(&my_vector), my_vector.size());
+//
+// util::gtl::InlinedVector<int, 4> my_inline_vector;
+// MyRoutine(my_inline_vector.array(), my_inline_vector.size());
+//
+// int my_array[10];
+// MyRoutine(my_array, 10);
+//
+// Instead, you can use an ArraySlice as the argument to the routine:
+//
+// void MyRoutine(ArraySlice<int> a) {
+// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. }
+// }
+//
+// This makes the call sites cleaner, for the most part:
+//
+// std::vector<int> my_vector;
+// MyRoutine(my_vector);
+//
+// util::gtl::InlinedVector<int, 4> my_inline_vector;
+// MyRoutine(my_inline_vector);
+//
+// int my_array[10];
+// MyRoutine(my_array);
+//
+// int* my_array = new int[10];
+// MyRoutine(gtl::ArraySlice<int>(my_array, 10));
+//
+// MutableArraySlice<T> represents a mutable array of elements, and, like
+// ArraySlice, does not own the backing store. The implicit constructors it
+// provides allow functions not to worry about whether their mutable arguments
+// refer to vectors, arrays, proto2::RepeatedFields, etc.:
+//
+// void MyMutatingRoutine(MutableArraySlice<int> a) {
+// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. }
+// }
+//
+// std::vector<int> my_vector;
+// MyMutatingRoutine(&my_vector);
+//
+// int my_array[10];
+// MyMutatingRoutine(my_array);
+//
+// int* my_array = new int[10];
+// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10));
+//
+// MyProto my_proto;
+// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); }
+// MyMutatingRoutine(my_proto.mutable_value());
+
+#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
+#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
+
+#include <initializer_list>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice_internal.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+namespace gtl {
+
+template <typename T>
+class ArraySlice {
+ private:
+ typedef array_slice_internal::ArraySliceImpl<T> Impl;
+
+ public:
+ typedef T value_type;
+ typedef typename Impl::pointer pointer;
+ typedef typename Impl::const_pointer const_pointer;
+ typedef typename Impl::reference reference;
+ typedef typename Impl::const_reference const_reference;
+ typedef typename Impl::iterator iterator;
+ typedef typename Impl::const_iterator const_iterator;
+ typedef typename Impl::reverse_iterator reverse_iterator;
+ typedef typename Impl::const_reverse_iterator const_reverse_iterator;
+ typedef typename Impl::size_type size_type;
+ typedef typename Impl::difference_type difference_type;
+
+ static const size_type npos = Impl::npos;
+
+ ArraySlice() : impl_(nullptr, 0) {}
+ ArraySlice(const_pointer array, size_type length) : impl_(array, length) {}
+
+ // Implicit conversion constructors
+ ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit)
+ : impl_(v.data(), v.size()) {}
+
+ template <size_t N>
+ ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit)
+ : impl_(a, N) {}
+
+ template <int N>
+ ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit)
+ : impl_(v.array(), v.size()) {}
+
+ // The constructor for any class supplying 'data() const' that returns either
+ // const T* or a less const-qualified version of it, and 'some_integral_type
+ // size() const'. proto2::RepeatedField<T>, string and (since C++11)
+ // std::vector<T,A> and std::array<T, N> are examples of this. See
+ // array_slice_internal.h for details.
+ template <typename V,
+ typename = typename Impl::template EnableIfConvertibleFrom<V>>
+ ArraySlice(const V& v) // NOLINT(runtime/explicit)
+ : impl_(v) {}
+
+ // Implicitly constructs an ArraySlice from an initializer list. This makes it
+ // possible to pass a brace-enclosed initializer list to a function expecting
+ // an ArraySlice:
+ // void Process(ArraySlice<int> x);
+ // Process({1, 2, 3});
+ // The data referenced by the initializer_list must outlive this
+ // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return
+ // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may
+ // reference data that is no longer valid.
+ ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit)
+ : impl_(v.begin(), v.size()) {}
+
+ // Substring of another ArraySlice.
+ // pos must be non-negative and <= x.length().
+ // len must be non-negative and will be pinned to at most x.length() - pos.
+ // If len==npos, the substring continues till the end of x.
+ ArraySlice(const ArraySlice& x, size_type pos, size_type len)
+ : impl_(x.impl_, pos, len) {}
+
+ const_pointer data() const { return impl_.data(); }
+ size_type size() const { return impl_.size(); }
+ size_type length() const { return size(); }
+ bool empty() const { return size() == 0; }
+
+ void clear() { impl_.clear(); }
+
+ const_reference operator[](size_type i) const { return impl_[i]; }
+ const_reference at(size_type i) const { return impl_.at(i); }
+ const_reference front() const { return impl_.front(); }
+ const_reference back() const { return impl_.back(); }
+
+ const_iterator begin() const { return impl_.begin(); }
+ const_iterator end() const { return impl_.end(); }
+ const_reverse_iterator rbegin() const { return impl_.rbegin(); }
+ const_reverse_iterator rend() const { return impl_.rend(); }
+
+ void remove_prefix(size_type n) { impl_.remove_prefix(n); }
+ void remove_suffix(size_type n) { impl_.remove_suffix(n); }
+ void pop_back() { remove_suffix(1); }
+ void pop_front() { remove_prefix(1); }
+
+ // These relational operators have the same semantics as the
+ // std::vector<T> relational operators: they do deep (elementwise)
+ // comparisons. Array slices are equal iff their size is the same
+ // and all their elements are equal.
+ bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; }
+ bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; }
+
+ private:
+ Impl impl_;
+};
+
+// Mutable version of ArraySlice, which allows the clients to mutate the
+// underlying data. It is implicitly convertible to ArraySlice since it provides
+// the data() and size() methods with correct signatures. When a
+// MutableArraySlice is created from a pointer to a container (as opposed to raw
+// memory pointer), the pointer must not be null.
+//
+// A note on const-ness: "mutable" here refers to the mutability of the
+// underlying data, not of the slice itself. It is perfectly reasonable to have
+// a variable of type "const MutableArraySlice<T>"; this means that the bounds
+// of the view on the array cannot be changed, but the underlying data in the
+// array still may be modified. This is akin to a "T* const" pointer, as opposed
+// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>).
+template <typename T>
+class MutableArraySlice {
+ private:
+ typedef array_slice_internal::MutableArraySliceImpl<T> Impl;
+
+ public:
+ typedef T value_type;
+ typedef typename Impl::pointer pointer;
+ typedef typename Impl::const_pointer const_pointer;
+ typedef typename Impl::reference reference;
+ typedef typename Impl::const_reference const_reference;
+ typedef typename Impl::iterator iterator;
+ typedef typename Impl::const_iterator const_iterator;
+ typedef typename Impl::reverse_iterator reverse_iterator;
+ typedef typename Impl::const_reverse_iterator const_reverse_iterator;
+ typedef typename Impl::size_type size_type;
+ typedef typename Impl::difference_type difference_type;
+
+ static const size_type npos = Impl::npos;
+
+ MutableArraySlice() : impl_(nullptr, 0) {}
+ MutableArraySlice(pointer array, size_type length) : impl_(array, length) {}
+
+ // Implicit conversion constructors
+ MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit)
+ : impl_(v->data(), v->size()) {}
+
+ template <size_t N>
+ MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit)
+ : impl_(a, N) {}
+
+ template <int N>
+ MutableArraySlice(
+ InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit)
+ : impl_(v->mutable_array(), v->size()) {}
+
+ // The constructor for any class supplying 'T* data()' or 'T* mutable_data()'
+ // (the former is called if both exist), and 'some_integral_type size()
+ // const'. proto2::RepeatedField is an example of this. Also supports string
+ // arguments, when T==char. The appropriate ctor is selected using SFINAE. See
+ // array_slice_internal.h for details.
+ template <typename V,
+ typename = typename Impl::template EnableIfConvertibleFrom<V>>
+ MutableArraySlice(V* v) // NOLINT(runtime/explicit)
+ : impl_(v) {}
+
+ // Substring of another MutableArraySlice.
+ // pos must be non-negative and <= x.length().
+ // len must be non-negative and will be pinned to at most x.length() - pos.
+ // If len==npos, the substring continues till the end of x.
+ MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len)
+ : impl_(x.impl_, pos, len) {}
+
+ // Accessors.
+ pointer data() const { return impl_.data(); }
+ size_type size() const { return impl_.size(); }
+ size_type length() const { return size(); }
+ bool empty() const { return size() == 0; }
+
+ void clear() { impl_.clear(); }
+
+ reference operator[](size_type i) const { return impl_[i]; }
+ reference at(size_type i) const { return impl_.at(i); }
+ reference front() const { return impl_.front(); }
+ reference back() const { return impl_.back(); }
+
+ iterator begin() const { return impl_.begin(); }
+ iterator end() const { return impl_.end(); }
+ reverse_iterator rbegin() const { return impl_.rbegin(); }
+ reverse_iterator rend() const { return impl_.rend(); }
+
+ void remove_prefix(size_type n) { impl_.remove_prefix(n); }
+ void remove_suffix(size_type n) { impl_.remove_suffix(n); }
+ void pop_back() { remove_suffix(1); }
+ void pop_front() { remove_prefix(1); }
+
+ bool operator==(ArraySlice<T> other) const {
+ return ArraySlice<T>(*this) == other;
+ }
+ bool operator!=(ArraySlice<T> other) const {
+ return ArraySlice<T>(*this) != other;
+ }
+
+ // DEPRECATED(jacobsa): Please use data() instead.
+ pointer mutable_data() const { return impl_.data(); }
+
+ private:
+ Impl impl_;
+};
+
+template <typename T>
+const typename ArraySlice<T>::size_type ArraySlice<T>::npos;
+template <typename T>
+const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos;
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
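
A compilable sketch of the call-site ergonomics described in the header comment, using hypothetical helpers Sum() and FillWith() that are not part of this change: the read-only view accepts any container exposing data()/size(), and the mutable view is taken by pointer-to-container, with element writes allowed even through a const slice, mirroring the "T* const" note above.

    #include <vector>
    #include "tensorflow/core/lib/gtl/array_slice.h"
    #include "tensorflow/core/lib/gtl/inlined_vector.h"

    namespace example {

    // Read-only view: callers can pass std::vector, InlinedVector, C arrays, ...
    int Sum(tensorflow::gtl::ArraySlice<int> values) {
      int total = 0;
      for (int v : values) total += v;
      return total;
    }

    // Mutable view: a const MutableArraySlice still permits element writes,
    // analogous to "T* const".
    void FillWith(const tensorflow::gtl::MutableArraySlice<int> out, int value) {
      for (size_t i = 0; i < out.size(); ++i) out[i] = value;
    }

    void Demo() {
      std::vector<int> vec = {1, 2, 3};
      int arr[4] = {0, 0, 0, 0};
      tensorflow::gtl::InlinedVector<int, 4> inl;
      inl.push_back(5);

      Sum(vec);        // from std::vector
      Sum(arr);        // from a C-style array
      Sum(inl);        // from InlinedVector
      Sum({7, 8, 9});  // from a braced list (valid only for this call)

      FillWith(&vec, 42);  // containers are passed by pointer for mutation
      FillWith(arr, 42);   // arrays convert directly
    }

    }  // namespace example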
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h
new file mode 100644
index 0000000000..080f0a38d8
--- /dev/null
+++ b/tensorflow/core/lib/gtl/array_slice_internal.h
@@ -0,0 +1,253 @@
+// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
+// array_slice.h.
+
+// Helper functions and templates for ArraySlice.
+
+#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
+#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace gtl {
+namespace array_slice_internal {
+
+// Template logic for generic constructors.
+
+// Wrappers whose Get() delegates to the appropriate method of a container, and
+// is defined when this method exists. Delegates to the const method if C is a
+// const type.
+struct Data {
+ template <typename C>
+ static decltype(std::declval<C>().data()) Get(C* v) {
+ return v->data();
+ }
+};
+
+struct MutableData {
+ template <typename C>
+ static decltype(std::declval<C>().mutable_data()) Get(C* v) {
+ return v->mutable_data();
+ }
+};
+
+struct Size {
+ template <typename C>
+ static decltype(std::declval<C>().size()) Get(C* v) {
+ return v->size();
+ }
+};
+
+struct MutableStringData {
+ // Defined only for string.
+ static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); }
+};
+
+// Checks whether M::Get(C*) is defined and has a return type R such that
+// Checker::valid<R>()==true.
+template <typename M, typename Checker, typename C>
+struct HasGetHelper : public M {
+ private:
+ struct None {};
+ // M::Get is selected when it is viable. Get(...) is selected otherwise.
+ using M::Get;
+ static None Get(...);
+
+ public:
+ static constexpr bool HasGet() {
+ using Result = decltype(Get(std::declval<C*>()));
+ return !std::is_same<Result, None>() && Checker::template valid<Result>();
+ }
+};
+
+// Defines HasGet() for a particular method, container, and checker. If
+// HasGet()==true, provides Get() that delegates to the method.
+template <typename M, typename Checker, typename C,
+ bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()>
+struct Wrapper {
+ static constexpr bool HasGet() { return false; }
+};
+
+template <typename M, typename Checker, typename C>
+struct Wrapper<M, Checker, C, true> {
+ static constexpr bool HasGet() { return true; }
+ static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); }
+};
+
+// Type checker for a method returning an integral value.
+struct SizeChecker {
+ template <typename R>
+ static constexpr bool valid() {
+ return std::is_integral<R>::value;
+ }
+};
+
+// Type checker for a method returning either a pointer to T or a less const
+// version of that.
+template <typename T>
+struct DataChecker {
+ // We want to enable conversion from std::vector<T*> to ArraySlice<const T*>,
+ // but disable conversion from std::vector<Derived> to ArraySlice<Base>. Here
+ // we use the fact that U** is convertible to Q* const* if and only if Q is
+ // the same type or a more cv-qualified version of U.
+ template <typename R>
+ static constexpr bool valid() {
+ return std::is_convertible<R*, T* const*>::value;
+ }
+};
+
+// Aliases to A if A::HasGet()==true, or to B otherwise.
+template <typename A, typename B>
+using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type;
+
+// Wraps C::data() const, returning a pointer to const data.
+template <typename T, typename C>
+using ContainerData = Wrapper<Data, DataChecker<const T>, const C>;
+
+// Wraps a method returning a pointer to mutable data. Prefers data() over
+// mutable_data(), and handles strings when T==char. If data() returns a pointer
+// to mutable data, it is most likely overloaded, but may also be a single
+// method 'T* C::data() const' in a non-STL-compliant container.
+template <typename T, typename C>
+using ContainerMutableData =
+ FirstWithGet<Wrapper<Data, DataChecker<T>, C>,
+ FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>,
+ Wrapper<MutableStringData, DataChecker<T>, C>>>;
+
+// Wraps C::size() const.
+template <typename C>
+using ContainerSize = Wrapper<Size, SizeChecker, const C>;
+
+// Implementation class for ArraySlice and MutableArraySlice. In the case of
+// ArraySlice, T will be a const type; for MutableArraySlice, T will be a
+// mutable type.
+template <typename T>
+class ArraySliceImplBase {
+ public:
+ typedef T* pointer;
+ typedef const T* const_pointer;
+ typedef T& reference;
+ typedef const T& const_reference;
+ typedef pointer iterator;
+ typedef const_pointer const_iterator;
+ typedef std::reverse_iterator<iterator> reverse_iterator;
+ typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+ typedef size_t size_type;
+ typedef ptrdiff_t difference_type;
+
+ static const size_type npos = -1;
+
+ ArraySliceImplBase(pointer array, size_type length)
+ : ptr_(array), length_(length) {}
+
+ // Substring of another ArraySlice.
+ // pos must be non-negative and <= x.length().
+ // len must be non-negative and will be pinned to at most x.length() - pos.
+ ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
+ : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {}
+
+ // Some of the const methods below return pointers and references to mutable
+ // data. This is only the case in this internal class; ArraySlice and
+ // MutableArraySlice provide deep-constness.
+
+ pointer data() const { return ptr_; }
+ size_type size() const { return length_; }
+
+ void clear() {
+ ptr_ = nullptr;
+ length_ = 0;
+ }
+
+ reference operator[](size_type i) const { return ptr_[i]; }
+ reference at(size_type i) const {
+ DCHECK_LT(i, length_);
+ return ptr_[i];
+ }
+ reference front() const {
+ DCHECK_GT(length_, 0);
+ return ptr_[0];
+ }
+ reference back() const {
+ DCHECK_GT(length_, 0);
+ return ptr_[length_ - 1];
+ }
+
+ void remove_prefix(size_type n) {
+ DCHECK_GE(length_, n);
+ ptr_ += n;
+ length_ -= n;
+ }
+ void remove_suffix(size_type n) {
+ DCHECK_GE(length_, n);
+ length_ -= n;
+ }
+
+ iterator begin() const { return ptr_; }
+ iterator end() const { return ptr_ + length_; }
+ reverse_iterator rbegin() const { return reverse_iterator(end()); }
+ reverse_iterator rend() const { return reverse_iterator(begin()); }
+
+ bool operator==(const ArraySliceImplBase& other) const {
+ if (size() != other.size()) return false;
+ if (data() == other.data()) return true;
+ return std::equal(data(), data() + size(), other.data());
+ }
+ bool operator!=(const ArraySliceImplBase& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ pointer ptr_;
+ size_type length_;
+};
+
+template <typename T>
+class ArraySliceImpl : public ArraySliceImplBase<const T> {
+ public:
+ using ArraySliceImplBase<const T>::ArraySliceImplBase;
+
+ // Defined iff the data and size accessors for the container C have been
+ // defined.
+ template <typename C>
+ using EnableIfConvertibleFrom =
+ typename std::enable_if<ContainerData<T, C>::HasGet() &&
+ ContainerSize<C>::HasGet()>::type;
+
+ // Constructs from a container when EnableIfConvertibleFrom is
+ // defined. std::addressof handles types with overloaded operator&.
+ template <typename C>
+ explicit ArraySliceImpl(const C& v)
+ : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)),
+ ContainerSize<C>::Get(std::addressof(v))) {}
+};
+
+template <typename T>
+class MutableArraySliceImpl : public ArraySliceImplBase<T> {
+ public:
+ using ArraySliceImplBase<T>::ArraySliceImplBase;
+
+ template <typename C>
+ using EnableIfConvertibleFrom =
+ typename std::enable_if<ContainerMutableData<T, C>::HasGet() &&
+ ContainerSize<C>::HasGet()>::type;
+
+ template <typename C>
+ explicit MutableArraySliceImpl(C* v)
+ : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v),
+ ContainerSize<C>::Get(v)) {}
+};
+
+} // namespace array_slice_internal
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
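
The Wrapper/HasGetHelper machinery above is a pre-C++17 detection idiom: overload resolution between the inherited M::Get(C*) and a Get(...) fallback decides at compile time whether a container exposes the required accessor with an acceptable return type. A stripped-down, self-contained sketch of the same trick (names are illustrative and independent of the classes above):

    #include <type_traits>
    #include <utility>
    #include <vector>

    // True iff C has a data() member callable on a const object whose result
    // converts to const T*. Same overload-resolution trick as HasGetHelper.
    template <typename T, typename C>
    class HasConstData {
      struct None {};
      // Viable (and preferred) only when const X has a data() member.
      template <typename X>
      static decltype(std::declval<const X&>().data()) Probe(const X*);
      static None Probe(...);  // Fallback when the member is absent.

     public:
      static constexpr bool value =
          std::is_convertible<decltype(Probe(static_cast<const C*>(nullptr))),
                              const T*>::value;
    };

    static_assert(HasConstData<int, std::vector<int>>::value,
                  "std::vector<int> provides data()");
    static_assert(!HasConstData<int, int>::value, "int does not provide data()");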
diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc
new file mode 100644
index 0000000000..33ee8fc8dd
--- /dev/null
+++ b/tensorflow/core/lib/gtl/array_slice_test.cc
@@ -0,0 +1,646 @@
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+#include <algorithm>
+#include <array>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+typedef ArraySlice<int> IntSlice;
+typedef ArraySlice<char> CharSlice;
+typedef MutableArraySlice<int> MutableIntSlice;
+typedef MutableArraySlice<char> MutableCharSlice;
+typedef std::vector<int> IntVec;
+
+// Append 0..len-1 to *v
+template <typename Vector>
+static void Fill(Vector* v, int len, int offset = 0) {
+ for (int i = 0; i < len; i++) {
+ v->push_back(i + offset);
+ }
+}
+
+static void TestHelper(const IntSlice& vorig, const IntVec& vec) {
+ IntSlice other; // To test the assignment return value.
+ IntSlice v = other = vorig;
+ const int len = vec.size();
+ EXPECT_EQ(v.size(), vec.size());
+
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(v[i], vec[i]);
+ EXPECT_EQ(v.at(i), vec[i]);
+ }
+ EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec));
+
+ int counter = 0;
+ for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(counter, *it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ counter = 0;
+ for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(counter, *it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ if (len > 0) {
+ EXPECT_EQ(0, v.front());
+ EXPECT_EQ(len - 1, v.back());
+ v.pop_back();
+ EXPECT_EQ(len - 1, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(i, v[i]);
+ }
+ if (len > 1) {
+ v.pop_front();
+ EXPECT_EQ(len - 2, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(i + 1, v[i]);
+ }
+ }
+ }
+}
+
+// The element access test that is applicable both when MutableArraySlice is
+// const and when it's not.
+template <class V>
+void MutableTestHelperTemplated(V v, int* ptr, const int len) {
+ CHECK_EQ(v.size(), len);
+
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(ptr + i, &v[i]);
+ EXPECT_EQ(ptr + i, &v.at(i));
+ }
+ EXPECT_EQ(ptr, v.begin());
+ EXPECT_EQ(ptr + len, v.end());
+ EXPECT_EQ(ptr, v.data());
+
+ int counter = 0;
+ for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(ptr + counter, &*it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ EXPECT_EQ(len, std::distance(v.rbegin(), v.rend()));
+
+ if (len > 0) {
+ EXPECT_EQ(ptr, &v.front());
+ EXPECT_EQ(ptr + len - 1, &v.back());
+ EXPECT_EQ(ptr + len - 1, &*v.rbegin());
+ EXPECT_EQ(ptr, &*(v.rend() - 1));
+ }
+}
+
+static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr,
+ const int len) {
+ // Test the data accessors both when the MutableArraySlice is declared const,
+ // and when it is not.
+ MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len);
+ MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len);
+
+ MutableIntSlice other; // To test the assignment return value.
+ MutableIntSlice v = other = vorig;
+ EXPECT_EQ(ptr, v.mutable_data());
+
+ int counter = 0;
+ for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) {
+ EXPECT_EQ(ptr + counter, &*it);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ if (len > 0) {
+ // Test that elements are assignable.
+ v[0] = 1;
+ v.front() = 2;
+ v.back() = 5;
+ *v.mutable_data() = 4;
+ std::fill(v.begin(), v.end(), 5);
+ std::fill(v.rbegin(), v.rend(), 6);
+ // Test size-changing methods.
+ v.pop_back();
+ EXPECT_EQ(len - 1, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(ptr + i, &v[i]);
+ }
+ if (len > 1) {
+ v.pop_front();
+ EXPECT_EQ(len - 2, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(ptr + i + 1, &v[i]);
+ }
+ }
+ }
+}
+
+template <typename Vector>
+static void TestImplicitConversion(const IntSlice& v, const Vector& vec) {
+ EXPECT_EQ(v.size(), vec.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ EXPECT_EQ(v[i], vec[i]);
+ }
+}
+
+template <typename Vector>
+static void TestImplicitConversion(const CharSlice& v, const Vector& vec) {
+ TestImplicitConversion(IntVec(v.begin(), v.end()), vec);
+}
+
+static void TestImplicitConversion(const MutableIntSlice& v, const int* data,
+ int size) {
+ EXPECT_EQ(size, v.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ EXPECT_EQ(data + i, &v[i]);
+ }
+}
+
+static void TestImplicitConversion(const MutableCharSlice& v, const char* data,
+ int size) {
+ EXPECT_EQ(size, v.size());
+ for (size_t i = 0; i < v.size(); i++) {
+ EXPECT_EQ(data + i, &v[i]);
+ }
+}
+
+// A struct supplying the data(), mutable_data() and size() methods, just like
+// e.g. proto2::RepeatedField.
+struct RepeatedField {
+ std::vector<int> storage;
+ const int* data() const { return storage.data(); }
+ int* mutable_data() { return storage.data(); }
+ int size() const { return storage.size(); }
+};
+
+// A struct supplying the data() (both mutable and const versions) and
+// size(). It also supplies mutable_data() but we test that data() is selected
+// instead.
+struct ContainerWithOverloads {
+ std::vector<int> storage;
+ std::vector<int> wrong_storage;
+ const int* data() const { return storage.data(); }
+ int* data() { return storage.data(); }
+ // MutableArraySlice should not call mutable_data(), preferring data()
+ // instead.
+ int* mutable_data() { return wrong_storage.data(); }
+ int size() const { return storage.size(); }
+};
+
+// A struct supplying data() and size() methods.
+struct ContainerWithShallowConstData {
+ std::vector<int> storage;
+ int* data() const { return const_cast<int*>(storage.data()); }
+ int size() const { return storage.size(); }
+};
+
+TEST(IntSlice, Simple) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec;
+ Fill(&vec, len);
+ TestHelper(IntSlice(vec), vec);
+ TestHelper(IntSlice(vec.data(), vec.size()), vec);
+ }
+}
+
+TEST(IntSlice, WithPosAndLen) {
+ IntVec vec;
+ Fill(&vec, 20);
+ for (size_t len = 0; len < vec.size(); len++) {
+ IntVec subvec(vec.begin(), vec.begin() + len);
+ TestImplicitConversion(IntSlice(vec, 0, len), subvec);
+ TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec);
+ }
+ EXPECT_EQ(0, IntSlice(vec, 0, 0).size());
+ EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size());
+ TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec);
+}
+
+TEST(IntSlice, Clear) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec;
+ Fill(&vec, len);
+ IntSlice v(vec);
+ v.clear();
+ EXPECT_EQ(0, v.size());
+ EXPECT_EQ(v.begin(), v.end());
+ }
+}
+
+TEST(IntSlice, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ for (int l2 = 0; l2 < 20; l2++) {
+ IntVec avec, bvec;
+ Fill(&avec, l1);
+ Fill(&bvec, l2, 100);
+ IntSlice a(avec), b(bvec);
+ using std::swap;
+ swap(a, b);
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ EXPECT_EQ(i, b[i]);
+ }
+ for (int i = 0; i < l2; i++) {
+ EXPECT_EQ(100 + i, a[i]);
+ }
+ }
+ }
+}
+
+TEST(IntSlice, ImplicitConversion) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec;
+ Fill(&vec, len);
+ IntSlice slice;
+ slice = vec;
+ TestImplicitConversion(vec, vec);
+ TestImplicitConversion(slice, vec);
+ TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec);
+ }
+}
+
+TEST(IntSlice, InlinedVectorConversion) {
+ for (int len = 0; len < 20; len++) {
+ InlinedVector<int, 4> inline_vec;
+ for (int i = 0; i < len; i++) {
+ inline_vec.push_back(i);
+ }
+ IntVec vec;
+ Fill(&vec, len);
+ IntSlice v = inline_vec; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(inline_vec, vec);
+ }
+}
+
+TEST(IntSlice, StaticArrayConversion) {
+ int array[20];
+ IntVec vec;
+ Fill(&vec, TF_ARRAYSIZE(array));
+ std::copy(vec.begin(), vec.end(), array);
+ IntSlice v = array; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(array, vec);
+}
+
+TEST(IntSlice, StdArrayConversion) {
+ std::array<int, 20> array;
+ IntVec vec;
+ Fill(&vec, array.size());
+ std::copy(vec.begin(), vec.end(), array.begin());
+
+ // Check assignment.
+ {
+ IntSlice v = array;
+ static_cast<void>(v);
+ }
+
+ // Check sub-slice initialization.
+ {
+ IntSlice v = {array, 10, 15};
+ static_cast<void>(v);
+ }
+
+ TestImplicitConversion(array, vec);
+}
+
+// Values according to the Fill function.
+static const int test_const_array[] = {0, 1, 2};
+
+TEST(IntSlice, ConstStaticArrayConversion) {
+ IntVec vec;
+ Fill(&vec, TF_ARRAYSIZE(test_const_array));
+ IntSlice v = test_const_array; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(test_const_array, vec);
+}
+
+TEST(IntSlice, RepeatedFieldConversion) {
+ RepeatedField repeated_field;
+ IntVec vec;
+ Fill(&vec, 20);
+ repeated_field.storage = vec;
+ IntSlice v = repeated_field; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(repeated_field, vec);
+}
+
+TEST(IntSlice, ContainerWithOverloadsConversion) {
+ ContainerWithOverloads container;
+ Fill(&container.storage, 20);
+ container.wrong_storage.resize(container.size());
+ IntSlice v = container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(container, container.storage);
+}
+
+TEST(IntSlice, ContainerWithShallowConstDataConversion) {
+ ContainerWithShallowConstData container;
+ Fill(&container.storage, 20);
+ IntSlice v = container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(container, container.storage);
+}
+
+TEST(IntSlice, MutableIntSliceConversion) {
+ IntVec vec(20);
+ IntSlice slice = MutableIntSlice(&vec);
+ EXPECT_EQ(vec.size(), slice.size());
+ EXPECT_EQ(vec.data(), slice.data());
+}
+
+TEST(IntSlice, Equality) {
+ IntVec vec1(20);
+ IntVec vec2(20);
+ // These two slices are from different vectors, but have the same
+ // size and have the same elements (right now). They should
+ // compare equal.
+ const IntSlice from1(vec1);
+ const IntSlice from2(vec2);
+ EXPECT_EQ(from1, from1);
+ EXPECT_EQ(from1, from2);
+
+ // This verifies that MutableArraySlices can be compared freely with
+ // ArraySlices.
+ const MutableIntSlice mutable_from1(&vec1);
+ const MutableIntSlice mutable_from2(&vec2);
+ EXPECT_EQ(from1, mutable_from1);
+ EXPECT_EQ(mutable_from1, from1);
+ EXPECT_EQ(mutable_from1, mutable_from2);
+ EXPECT_EQ(mutable_from2, mutable_from1);
+
+ // With a different size, the array slices should not be equal.
+ EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1));
+
+ // With different contents, the array slices should not be equal.
+ ++vec2.back();
+ EXPECT_NE(from1, from2);
+}
+
+// Compile-asserts that the argument has the expected type.
+template <typename Expected, typename T>
+void CheckType(const T& value) {
+ testing::StaticAssertTypeEq<Expected, T>();
+}
+
+TEST(IntSlice, ExposesContainerTypesAndConsts) {
+ IntSlice slice;
+ const IntSlice const_slice;
+ CheckType<IntSlice::iterator>(slice.begin());
+ CheckType<IntSlice::const_iterator>(const_slice.end());
+ CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin());
+ CheckType<IntSlice::reverse_iterator>(slice.rend());
+ testing::StaticAssertTypeEq<int, IntSlice::value_type>();
+ testing::StaticAssertTypeEq<const int*, IntSlice::pointer>();
+ testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>();
+ EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos);
+}
+
+void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); }
+
+void TestRange(IntSlice slice, int from, int to) {
+ ASSERT_EQ(to - from + 1, slice.size());
+ for (size_t i = 0; i < slice.size(); ++i) {
+ EXPECT_EQ(from + i, slice[i]);
+ }
+}
+
+TEST(IntSlice, InitializerListConversion) {
+ TestEmpty({});
+ TestRange({1}, 1, 1);
+ TestRange({10, 11, 12, 13}, 10, 13);
+}
+
+TEST(CharSlice, StringConversion) {
+ IntVec vec;
+ Fill(&vec, 20);
+ string str(vec.begin(), vec.end());
+ CharSlice v = str; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(str, vec);
+}
+
+TEST(IntPtrSlice, ConstConversion) {
+ int one = 1;
+ int two = 2;
+ std::vector<int*> vec;
+ vec.push_back(&one);
+ vec.push_back(&two);
+ ArraySlice<const int*> v = vec;
+ ASSERT_EQ(2, v.size());
+ EXPECT_EQ(&one, v[0]);
+ EXPECT_EQ(&two, v[1]);
+}
+
+TEST(MutableIntSlice, Simple) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec(len);
+ MutableTestHelper(MutableIntSlice(&vec), vec.data(), len);
+ MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len);
+ }
+}
+
+TEST(MutableIntSlice, WithPosAndLen) {
+ IntVec vec(20);
+ for (size_t len = 0; len < vec.size(); len++) {
+ TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len);
+ TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len),
+ vec.data(), len);
+ }
+ EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size());
+ EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size());
+ TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos),
+ vec.data(), vec.size());
+}
+
+TEST(MutableIntSlice, Clear) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec(len);
+ MutableIntSlice v(&vec);
+ v.clear();
+ EXPECT_EQ(0, v.size());
+ EXPECT_EQ(v.begin(), v.end());
+ }
+}
+
+TEST(MutableIntSlice, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ for (int l2 = 0; l2 < 20; l2++) {
+ IntVec avec(l1), bvec(l2);
+ MutableIntSlice a(&avec), b(&bvec);
+ using std::swap;
+ swap(a, b);
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ EXPECT_EQ(&avec[i], &b[i]);
+ }
+ for (int i = 0; i < l2; i++) {
+ EXPECT_EQ(&bvec[i], &a[i]);
+ }
+ }
+ }
+}
+
+TEST(MutableIntSlice, ImplicitConversion) {
+ for (int len = 0; len < 20; len++) {
+ IntVec vec(len);
+ MutableIntSlice slice;
+ slice = &vec;
+ TestImplicitConversion(&vec, vec.data(), len);
+ TestImplicitConversion(slice, vec.data(), len);
+ TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(),
+ len);
+ }
+}
+
+TEST(MutableIntSlice, InlinedVectorConversion) {
+ for (int len = 0; len < 20; len++) {
+ InlinedVector<int, 4> inline_vec;
+ for (int i = 0; i < len; i++) {
+ inline_vec.push_back(i);
+ }
+ MutableIntSlice v = &inline_vec; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&inline_vec, inline_vec.array(), inline_vec.size());
+ }
+}
+
+TEST(MutableIntSlice, StaticArrayConversion) {
+ int array[20];
+ MutableIntSlice v = array; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(array, array, TF_ARRAYSIZE(array));
+}
+
+TEST(MutableIntSlice, StdArrayConversion) {
+ std::array<int, 20> array;
+
+ // Check assignment.
+ {
+ MutableIntSlice v = &array;
+ static_cast<void>(v);
+ }
+
+ // Check sub-slice initialization.
+ {
+ MutableIntSlice v = {&array, 10, 15};
+ static_cast<void>(v);
+ }
+
+ TestImplicitConversion(&array, &array[0], array.size());
+}
+
+TEST(MutableIntSlice, RepeatedFieldConversion) {
+ RepeatedField repeated_field;
+ Fill(&repeated_field.storage, 20);
+ MutableIntSlice v = &repeated_field; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&repeated_field, repeated_field.storage.data(),
+ repeated_field.storage.size());
+}
+
+TEST(MutableIntSlice, ContainerWithOverloadsConversion) {
+ ContainerWithOverloads container;
+ Fill(&container.storage, 20);
+ container.wrong_storage.resize(container.size());
+ MutableIntSlice v = &container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&container, container.storage.data(),
+ container.storage.size());
+}
+
+TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) {
+ ContainerWithShallowConstData container;
+ Fill(&container.storage, 20);
+ MutableIntSlice v = &container; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(&container, container.storage.data(),
+ container.storage.size());
+}
+
+TEST(MutableIntSlice, TypedefsAndConstants) {
+ testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>();
+ testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>();
+ testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>();
+ testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>();
+ testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>();
+
+ EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos);
+}
+
+TEST(MutableIntSlice, IteratorsAndReferences) {
+ auto accept_pointer = [](int* x) {};
+ auto accept_reference = [](int& x) {};
+ auto accept_iterator = [](MutableIntSlice::iterator x) {};
+ auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
+
+ int a[1];
+ MutableIntSlice s = a;
+
+ accept_pointer(s.data());
+ accept_pointer(s.mutable_data());
+ accept_iterator(s.begin());
+ accept_iterator(s.end());
+ accept_reverse_iterator(s.rbegin());
+ accept_reverse_iterator(s.rend());
+
+ accept_reference(s[0]);
+ accept_reference(s.at(0));
+ accept_reference(s.front());
+ accept_reference(s.back());
+}
+
+TEST(MutableIntSlice, IteratorsAndReferences_Const) {
+ auto accept_pointer = [](int* x) {};
+ auto accept_reference = [](int& x) {};
+ auto accept_iterator = [](MutableIntSlice::iterator x) {};
+ auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
+
+ int a[1];
+ const MutableIntSlice s = a;
+
+ accept_pointer(s.data());
+ accept_pointer(s.mutable_data());
+ accept_iterator(s.begin());
+ accept_iterator(s.end());
+ accept_reverse_iterator(s.rbegin());
+ accept_reverse_iterator(s.rend());
+
+ accept_reference(s[0]);
+ accept_reference(s.at(0));
+ accept_reference(s.front());
+ accept_reference(s.back());
+}
+
+bool TestMutableOverload(MutableIntSlice slice) { return false; }
+
+bool TestMutableOverload(MutableCharSlice slice) { return true; }
+
+TEST(MutableCharSlice, StringConversion) {
+ for (int len = 0; len < 20; len++) {
+ string str(len, '\0');
+ MutableCharSlice v = &str; // Test assignment
+ static_cast<void>(v);
+ TestImplicitConversion(v, str.data(), str.size());
+ }
+ // Verify that only the correct overload is feasible. Note that this would
+ // fail if the string ctor was declared simply as MutableArraySlice(string*),
+ // since in that case both overloads would be feasible.
+ string str;
+ EXPECT_TRUE(TestMutableOverload(&str));
+}
+
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h
new file mode 100644
index 0000000000..82b6c2299f
--- /dev/null
+++ b/tensorflow/core/lib/gtl/edit_distance.h
@@ -0,0 +1,82 @@
+#ifndef TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_
+#define TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// Calculate the Levenshtein Edit Distance between two contiguous
+// sequences, s and t, of type T.
+//
+// The Levenshtein distance is a symmetric distance defined as the
+// smallest number of insertions, deletions, and substitutions
+// required to convert sequence s to t (and vice versa).
+// Note that this distance does not consider transpositions.
+//
+// For more details and a reference implementation, see:
+// https://en.wikipedia.org/wiki/Levenshtein_distance
+//
+// This implementation has time complexity O(|s|*|t|)
+// and space complexity O(min(|s|, |t|)), where
+// |x| := x.size()
+//
+// A simple call to LevenshteinDistance looks like:
+//
+// int64 dist = LevenshteinDistance(string("hi"), string("bye"),
+// std::equal_to<char>());
+//
+template <typename T, typename Cmp>
+inline int64 LevenshteinDistance(const gtl::ArraySlice<T>& s,
+ const gtl::ArraySlice<T>& t, const Cmp& cmp) {
+ const int64 s_size = s.size();
+ const int64 t_size = t.size();
+
+ if (s_size == 0) return t_size;
+ if (t_size == 0) return s_size;
+ if (s == t) return 0;
+ if (t_size > s_size) return LevenshteinDistance(t, s, cmp);
+
+ // Create work vectors
+ gtl::InlinedVector<int64, 32> scratch0(t_size + 1);
+ gtl::InlinedVector<int64, 32> scratch1(t_size + 1);
+
+ int64* previous = scratch0.data();
+ int64* current = scratch1.data();
+
+ // Initialize previous row of distances
+ std::iota(scratch0.begin(), scratch0.end(), 0);
+
+ for (int64 i = 0; i < s_size; ++i) {
+ // Swap current and previous rows for next iteration
+ std::swap(previous, current);
+
+ // Calculate current row distances from previous row
+ current[0] = i + 1;
+
+ // Fill in the rest of the row
+ for (int64 j = 0; j < t_size; ++j) {
+ const int64 cost = cmp(s[i], t[j]) ? 0 : 1;
+ current[j + 1] =
+ std::min(current[j] + 1, // deletion cost
+ std::min(previous[j + 1] + 1, // insertion cost
+ previous[j] + cost)); // substitution cost
+ }
+ }
+
+ return current[t_size];
+}
+
+template <typename Container1, typename Container2, typename Cmp>
+inline int64 LevenshteinDistance(const Container1& s, const Container2& t,
+ const Cmp& cmp) {
+ return LevenshteinDistance(
+ gtl::ArraySlice<typename Container1::value_type>(s.data(), s.size()),
+ gtl::ArraySlice<typename Container1::value_type>(t.data(), t.size()),
+ cmp);
+}
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_EDIT_DISTANCE_H_
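
For intuition about the rolling two-row DP above, here is a hand-checkable call through the container overload; the classic "kitten" -> "sitting" pair needs exactly three edits (two substitutions and one insertion). Sketch only, not part of the change; CHECK_EQ comes from the platform logging header.

    #include <functional>
    #include <string>
    #include "tensorflow/core/lib/gtl/edit_distance.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/platform/port.h"

    void EditDistanceExample() {
      // k->s and e->i are substitutions; the trailing 'g' is an insertion.
      tensorflow::int64 d = tensorflow::gtl::LevenshteinDistance(
          std::string("kitten"), std::string("sitting"), std::equal_to<char>());
      CHECK_EQ(d, 3);
    }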
diff --git a/tensorflow/core/lib/gtl/edit_distance_test.cc b/tensorflow/core/lib/gtl/edit_distance_test.cc
new file mode 100644
index 0000000000..0526ee0a05
--- /dev/null
+++ b/tensorflow/core/lib/gtl/edit_distance_test.cc
@@ -0,0 +1,125 @@
+#include "tensorflow/core/lib/gtl/edit_distance.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+class LevenshteinDistanceTest : public ::testing::Test {
+ protected:
+ std::vector<char> empty_;
+ std::string s1_;
+ std::string s1234_;
+ std::string s567_;
+ std::string kilo_;
+ std::string kilogram_;
+ std::string mother_;
+ std::string grandmother_;
+ std::string lower_;
+ std::string upper_;
+
+ void SetUp() override {
+ s1_ = "1";
+ s1234_ = "1234";
+ s567_ = "567";
+ kilo_ = "kilo";
+ kilogram_ = "kilogram";
+ mother_ = "mother";
+ grandmother_ = "grandmother";
+ lower_ = "lower case";
+ upper_ = "UPPER case";
+ }
+};
+
+TEST_F(LevenshteinDistanceTest, BothEmpty) {
+ ASSERT_EQ(LevenshteinDistance(empty_, empty_, std::equal_to<char>()), 0);
+}
+
+TEST_F(LevenshteinDistanceTest, OneEmpty) {
+ ASSERT_EQ(LevenshteinDistance(s1234_, empty_, std::equal_to<char>()), 4);
+ ASSERT_EQ(LevenshteinDistance(empty_, s567_, std::equal_to<char>()), 3);
+}
+
+TEST_F(LevenshteinDistanceTest, SingleElement) {
+ ASSERT_EQ(LevenshteinDistance(s1234_, s1_, std::equal_to<char>()), 3);
+ ASSERT_EQ(LevenshteinDistance(s1_, s1234_, std::equal_to<char>()), 3);
+}
+
+TEST_F(LevenshteinDistanceTest, Prefix) {
+ ASSERT_EQ(LevenshteinDistance(kilo_, kilogram_, std::equal_to<char>()), 4);
+ ASSERT_EQ(LevenshteinDistance(kilogram_, kilo_, std::equal_to<char>()), 4);
+}
+
+TEST_F(LevenshteinDistanceTest, Suffix) {
+ ASSERT_EQ(LevenshteinDistance(mother_, grandmother_, std::equal_to<char>()),
+ 5);
+ ASSERT_EQ(LevenshteinDistance(grandmother_, mother_, std::equal_to<char>()),
+ 5);
+}
+
+TEST_F(LevenshteinDistanceTest, DifferentComparisons) {
+ ASSERT_EQ(LevenshteinDistance(lower_, upper_, std::equal_to<char>()), 5);
+ ASSERT_EQ(LevenshteinDistance(upper_, lower_, std::equal_to<char>()), 5);
+ ASSERT_EQ(
+ LevenshteinDistance(gtl::ArraySlice<char>(lower_.data(), lower_.size()),
+ gtl::ArraySlice<char>(upper_.data(), upper_.size()),
+ std::equal_to<char>()),
+ 5);
+ auto no_case_cmp = [](char c1, char c2) {
+ return std::tolower(c1) == std::tolower(c2);
+ };
+ ASSERT_EQ(LevenshteinDistance(lower_, upper_, no_case_cmp), 3);
+ ASSERT_EQ(LevenshteinDistance(upper_, lower_, no_case_cmp), 3);
+}
+
+TEST_F(LevenshteinDistanceTest, Vectors) {
+ ASSERT_EQ(
+ LevenshteinDistance(std::string("algorithm"), std::string("altruistic"),
+ std::equal_to<char>()),
+ 6);
+}
+
+static void BM_EditDistanceHelper(int n, int len, bool completely_different) {
+ string a =
+ "The quick brown fox jumped over the lazy dog and on and on and on"
+ " Every good boy deserves fudge. In fact, this is a very long sentence "
+ " w/many bytes..";
+ while (a.size() < static_cast<size_t>(len)) {
+ a = a + a;
+ }
+ string b = a;
+ if (completely_different) {
+ for (size_t i = 0; i < b.size(); i++) {
+ b[i]++;
+ }
+ }
+ while (n-- > 0) {
+ LevenshteinDistance(gtl::ArraySlice<char>(a.data(), len),
+ gtl::ArraySlice<char>(b.data(), len),
+ std::equal_to<char>());
+ }
+}
+
+static void BM_EditDistanceSame(int n, int len) {
+ BM_EditDistanceHelper(n, len, false);
+}
+static void BM_EditDistanceDiff(int n, int len) {
+ BM_EditDistanceHelper(n, len, true);
+}
+
+BENCHMARK(BM_EditDistanceSame)->Arg(5);
+BENCHMARK(BM_EditDistanceSame)->Arg(50);
+BENCHMARK(BM_EditDistanceSame)->Arg(200);
+BENCHMARK(BM_EditDistanceSame)->Arg(1000);
+BENCHMARK(BM_EditDistanceDiff)->Arg(5);
+BENCHMARK(BM_EditDistanceDiff)->Arg(50);
+BENCHMARK(BM_EditDistanceDiff)->Arg(200);
+BENCHMARK(BM_EditDistanceDiff)->Arg(1000);
+
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
new file mode 100644
index 0000000000..c23075129c
--- /dev/null
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -0,0 +1,839 @@
+// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
+// for sequences of length <= N is provided inline without requiring
+// any heap allocation. Typically N is very small (e.g., 4) so that
+// sequences that are expected to be short do not require allocations.
+//
+// Only some of the std::vector<> operations are currently implemented.
+// Other operations may be added as needed to facilitate migrating
+// code that uses std::vector<> to InlinedVector<>.
+//
+// NOTE: If you want an inlined version to replace use of a
+// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
+// in util/bitmap/inlined_bitvector.h
+//
+// TODO(billydonahue): change size_t to size_type where appropriate.
+
+#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
+#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <type_traits>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+
+#include <initializer_list> // NOLINT(build/include_order)
+
+namespace tensorflow {
+namespace gtl {
+
+template <typename T, int N, typename A = std::allocator<T> >
+class InlinedVector {
+ public:
+ typedef A allocator_type;
+ typedef typename allocator_type::value_type value_type;
+ typedef typename allocator_type::pointer pointer;
+ typedef typename allocator_type::const_pointer const_pointer;
+ typedef typename allocator_type::reference reference;
+ typedef typename allocator_type::const_reference const_reference;
+ typedef typename allocator_type::size_type size_type;
+ typedef typename allocator_type::difference_type difference_type;
+ typedef pointer iterator;
+ typedef const_pointer const_iterator;
+
+ // Create an empty vector
+ InlinedVector();
+ explicit InlinedVector(const allocator_type& alloc);
+
+ // Create a vector with n copies of value_type().
+ explicit InlinedVector(size_t n);
+
+ // Create a vector with n copies of elem
+ InlinedVector(size_t n, const value_type& elem,
+ const allocator_type& alloc = allocator_type());
+
+ // Create and initialize with the elements [range_start .. range_end).
+ // The unused enable_if argument restricts this constructor so that it is
+ // elided when value_type is an integral type. This prevents ambiguous
+ // interpretation between a call to this constructor with two integral
+ // arguments and a call to the preceding (n, elem) constructor.
+ template <typename InputIterator>
+ InlinedVector(
+ InputIterator range_start, InputIterator range_end,
+ const allocator_type& alloc = allocator_type(),
+ typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
+ NULL)
+ : allocator_and_tag_(alloc) {
+ AppendRange(range_start, range_end);
+ }
+
+ InlinedVector(std::initializer_list<value_type> init,
+ const allocator_type& alloc = allocator_type())
+ : allocator_and_tag_(alloc) {
+ AppendRange(init.begin(), init.end());
+ }
+
+ InlinedVector(const InlinedVector& v);
+
+ ~InlinedVector() { clear(); }
+
+ InlinedVector& operator=(const InlinedVector& v) {
+ // Optimized to avoid reallocation.
+ // Prefer reassignment to copy construction for elements.
+ if (size() < v.size()) { // grow
+ reserve(v.size());
+ std::copy(v.begin(), v.begin() + size(), begin());
+ std::copy(v.begin() + size(), v.end(), std::back_inserter(*this));
+ } else { // maybe shrink
+ erase(begin() + v.size(), end());
+ std::copy(v.begin(), v.end(), begin());
+ }
+ return *this;
+ }
+
+ size_t size() const {
+ return allocated() ? allocation().size() : tag().size();
+ }
+
+ bool empty() const { return (size() == 0); }
+
+ // Return number of elements that can be stored in vector
+ // without requiring a reallocation of underlying memory
+ size_t capacity() const { return allocated() ? allocation().capacity() : N; }
+
+ // Return a pointer to the underlying array.
+ // Only result[0,size()-1] are defined.
+ const_pointer data() const {
+ return allocated() ? allocated_space() : inlined_space();
+ }
+ pointer data() { return allocated() ? allocated_space() : inlined_space(); }
+
+ // An older name for the more standard-friendly .data().
+ const_pointer array() const { return data(); }
+ pointer mutable_array() { return data(); }
+
+ // Remove all elements
+ void clear() {
+ size_t s = size();
+ if (allocated()) {
+ DestroyAllocated(allocated_space(), allocated_space() + s);
+ allocation().Dealloc(allocator());
+ } else {
+ DestroyInlined(inlined_space(), inlined_space() + s);
+ }
+ tag() = Tag();
+ }
+
+ // Return the ith element
+ // REQUIRES: 0 <= i < size()
+ const value_type& at(size_t i) const {
+ DCHECK_LT(i, size());
+ return array()[i];
+ }
+ const value_type& operator[](size_t i) const {
+ DCHECK_LT(i, size());
+ return array()[i];
+ }
+
+ // Return a non-const reference to the ith element
+ // REQUIRES: 0 <= i < size()
+ value_type& at(size_t i) {
+ DCHECK_LT(i, size());
+ return mutable_array()[i];
+ }
+ value_type& operator[](size_t i) {
+ DCHECK_LT(i, size());
+ return mutable_array()[i];
+ }
+
+ value_type& back() {
+ DCHECK(!empty());
+ return at(size() - 1);
+ }
+
+ const value_type& back() const {
+ DCHECK(!empty());
+ return at(size() - 1);
+ }
+
+ value_type& front() {
+ DCHECK(!empty());
+ return at(0);
+ }
+
+ const value_type& front() const {
+ DCHECK(!empty());
+ return at(0);
+ }
+
+ // Append t to the vector.
+ // Increases size() by one.
+ // Amortized complexity: O(1)
+ // Worst-case complexity: O(size())
+ void push_back(const value_type& t) {
+ size_t s = size();
+ DCHECK_LE(s, capacity());
+ if (s == capacity()) {
+ return GrowAndPushBack(t);
+ }
+ DCHECK_LT(s, capacity());
+
+ if (allocated()) {
+ ConstructAllocated(allocated_space() + s, t);
+ } else {
+ ConstructInlined(inlined_space() + s, t);
+ }
+
+ set_size_internal(s + 1);
+ }
+
+ void pop_back() {
+ DCHECK(!empty());
+ size_t s = size();
+ if (allocated()) {
+ DestroyAllocated(allocated_space() + s - 1, allocated_space() + s);
+ } else {
+ DestroyInlined(inlined_space() + s - 1, inlined_space() + s);
+ }
+ set_size_internal(s - 1);
+ }
+
+ // Resizes the vector to contain "n" elements.
+ // If "n" is smaller than the initial size, extra elements are destroyed.
+ // If "n" is larger than the initial size, enough copies of "elem"
+ // are appended to increase the size to "n". If "elem" is omitted,
+ // new elements are value-initialized.
+ void resize(size_t n);
+ void resize(size_t n, const value_type& elem);
+
+ iterator begin() { return mutable_array(); }
+ const_iterator begin() const { return array(); }
+
+ iterator end() { return mutable_array() + size(); }
+ const_iterator end() const { return array() + size(); }
+
+ iterator insert(iterator pos, const value_type& v);
+
+ iterator erase(iterator pos) {
+ DCHECK_LT(pos, end());
+ DCHECK_GE(pos, begin());
+ std::copy(pos + 1, end(), pos);
+ pop_back();
+ return pos;
+ }
+
+ iterator erase(iterator first, iterator last);
+
+ // Enlarges the underlying representation so it can hold at least
+ // "n" elements without reallocation.
+ // Does not change size() or the actual contents of the vector.
+ void reserve(size_t n) {
+ if (n > capacity()) {
+ // Make room for new elements
+ EnlargeBy(n - size());
+ }
+ }
+
+ // Swap the contents of *this with other.
+ // REQUIRES: value_type is swappable and copyable.
+ void swap(InlinedVector& other);
+
+ allocator_type get_allocator() const { return allocator(); }
+
+ private:
+ struct AllocatorTraits {
+ typedef typename allocator_type::value_type value_type;
+ typedef typename allocator_type::pointer pointer;
+ typedef typename allocator_type::size_type size_type;
+
+ static void construct(allocator_type& a, // NOLINT(runtime/references)
+ pointer p) {
+ // Tricky: do we support non-copyable types, or support allocators
+ // that do special things with construct()? Non-copyable types are
+ // needed today, so they are more important. When we sort out the
+ // Android NDK C++11 problem, we will be able to use the proper
+ // std::allocator_traits<A>::construct(p, ...).
+ //
+ // a.construct(p, value_type());
+ new (p) value_type();
+ }
+ static void construct(allocator_type& a, // NOLINT(runtime/references)
+ pointer p, const value_type& t) {
+ a.construct(p, t);
+ }
+ static void destroy(allocator_type& a, // NOLINT(runtime/references)
+ pointer p) {
+ a.destroy(p);
+ }
+ static pointer allocate(allocator_type& a, // NOLINT(runtime/references)
+ size_type n) {
+ return a.allocate(n);
+ }
+ static void deallocate(allocator_type& a, // NOLINT(runtime/references)
+ pointer p, size_type n) {
+ a.deallocate(p, n);
+ }
+ };
+
+ // If the vector is inlined, holds the size of the vector.
+ // If the vector is allocated, holds the special value kAllocated,
+ // and the size is stored in the vector's Allocation.
+ class Tag {
+ public:
+ Tag() : size_(0) {}
+ size_t size() const { return size_; }
+ void set_size(size_t n) { size_ = n; }
+ bool allocated() const { return size_ == kAllocated; }
+ void set_allocated() { size_ = kAllocated; }
+
+ private:
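+    // Note: -1 converted to size_t is its maximum value, which can never be a
+    // valid inline size (the inline size is at most N), so it is safe to use
+    // as the "allocated" sentinel.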
+ static const size_t kAllocated = -1;
+ size_t size_;
+ };
+
+ // Derives from allocator_type to use the empty base class optimization.
+ // If the allocator_type is stateless, we can 'store'
+ // our instance of it for free.
+ class AllocatorAndTag : private allocator_type {
+ public:
+ explicit AllocatorAndTag(const allocator_type& a, Tag t = Tag())
+ : allocator_type(a), tag_(t) {}
+ Tag& tag() { return tag_; }
+ const Tag& tag() const { return tag_; }
+ allocator_type& allocator() { return *this; }
+ const allocator_type& allocator() const { return *this; }
+
+ private:
+ Tag tag_;
+ };
+
+ class Allocation {
+ public:
+ Allocation(allocator_type& a, // NOLINT(runtime/references)
+ size_t capacity)
+ : size_(0),
+ capacity_(capacity),
+ buffer_(AllocatorTraits::allocate(a, capacity_)) {}
+
+ void Dealloc(allocator_type& a) { // NOLINT(runtime/references)
+ AllocatorTraits::deallocate(a, buffer(), capacity());
+ }
+
+ size_t size() const { return size_; }
+ void set_size(size_t s) { size_ = s; }
+ size_t capacity() const { return capacity_; }
+ const value_type* buffer() const { return buffer_; }
+ value_type* buffer() { return buffer_; }
+
+ private:
+ size_t size_;
+ size_t capacity_;
+ value_type* buffer_;
+ };
+
+ const Tag& tag() const { return allocator_and_tag_.tag(); }
+ Tag& tag() { return allocator_and_tag_.tag(); }
+
+ Allocation& allocation() { return *rep_.allocation_storage.allocation.get(); }
+ const Allocation& allocation() const {
+ return *rep_.allocation_storage.allocation.get();
+ }
+ void init_allocation(const Allocation& allocation) {
+ rep_.allocation_storage.allocation.Init(allocation);
+ }
+
+ value_type* inlined_space() { return rep_.inlined_storage.inlined[0].get(); }
+ const value_type* inlined_space() const {
+ return rep_.inlined_storage.inlined[0].get();
+ }
+
+ value_type* allocated_space() { return allocation().buffer(); }
+ const value_type* allocated_space() const { return allocation().buffer(); }
+
+ const allocator_type& allocator() const {
+ return allocator_and_tag_.allocator();
+ }
+ allocator_type& allocator() { return allocator_and_tag_.allocator(); }
+
+ bool allocated() const { return tag().allocated(); }
+ void set_allocated() { return tag().set_allocated(); }
+
+ void set_size_internal(size_t n) {
+ if (allocated()) {
+ allocation().set_size(n);
+ } else {
+ tag().set_size(n);
+ }
+ }
+
+ // Enlarge the underlying representation so we can store size_ + delta elems.
+ // The size is not changed, and any newly added memory is not initialized.
+ void EnlargeBy(size_t delta);
+
+ void ResetAllocation(Allocation new_allocation) {
+ if (allocated()) {
+ DestroyAllocated(allocated_space(), allocated_space() + size());
+ DCHECK_EQ(begin(), allocated_space());
+ allocation().Dealloc(allocator());
+ allocation() = new_allocation;
+ } else {
+ DestroyInlined(inlined_space(), inlined_space() + size());
+      init_allocation(new_allocation);  // first (and only) Init() of the union's allocation slot
+ set_allocated();
+ }
+ }
+
+ void GrowAndPushBack(const value_type& t) {
+ DCHECK_EQ(size(), capacity());
+ const size_t s = size();
+
+ Allocation new_allocation(allocator(), 2 * capacity());
+ new_allocation.set_size(s + 1);
+
+ UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer());
+ ConstructAllocated(new_allocation.buffer() + s, t);
+
+ ResetAllocation(new_allocation);
+ }
+
+ void InitAssign(size_t n);
+ void InitAssign(size_t n, const value_type& t);
+
+ void ConstructInlined(pointer p) { new (p) value_type(); }
+
+ void ConstructInlined(pointer p, const value_type& t) {
+ new (p) value_type(t);
+ }
+
+ void ConstructAllocated(pointer p) {
+ AllocatorTraits::construct(allocator(), p);
+ }
+ void ConstructAllocated(pointer p, const value_type& t) {
+ AllocatorTraits::construct(allocator(), p, t);
+ }
+
+ template <typename Iter>
+ void UninitializedCopyInlined(Iter src, Iter src_last, value_type* dst) {
+ std::uninitialized_copy(src, src_last, dst);
+ }
+
+ template <typename Iter>
+ void UninitializedCopyAllocated(Iter src, Iter src_last, value_type* dst) {
+ for (; src != src_last; ++dst, ++src) ConstructAllocated(dst, *src);
+ }
+
+ void UninitializedFillInlined(value_type* dst, value_type* dst_last) {
+ for (; dst != dst_last; ++dst) ConstructInlined(dst);
+ }
+ void UninitializedFillInlined(value_type* dst, value_type* dst_last,
+ const value_type& t) {
+ std::uninitialized_fill(dst, dst_last, t);
+ }
+
+ void UninitializedFillAllocated(value_type* dst, value_type* dst_last) {
+ for (; dst != dst_last; ++dst) ConstructAllocated(dst);
+ }
+ void UninitializedFillAllocated(value_type* dst, value_type* dst_last,
+ const value_type& t) {
+ for (; dst != dst_last; ++dst) ConstructAllocated(dst, t);
+ }
+
+ // Destroy [ptr, ptr_last) in place.
+ void DestroyInlined(value_type* ptr, value_type* ptr_last);
+ void DestroyAllocated(value_type* ptr, value_type* ptr_last);
+
+ template <typename Iter>
+ void AppendRange(Iter first, Iter last, std::input_iterator_tag);
+
+ // Faster path for forward iterators.
+ template <typename Iter>
+ void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
+
+ template <typename Iter>
+ void AppendRange(Iter first, Iter last);
+
+ AllocatorAndTag allocator_and_tag_;
+
+ // Either the inlined or allocated representation
+ union Rep {
+ // Use struct to perform indirection that solves a bizarre compilation
+ // error on Visual Studio (all known versions).
+ struct {
+ tensorflow::ManualConstructor<value_type> inlined[N];
+ } inlined_storage;
+ struct {
+ tensorflow::ManualConstructor<Allocation> allocation;
+ } allocation_storage;
+ } rep_;
+};
+
+template <typename T, int N, typename A>
+const size_t InlinedVector<T, N, A>::Tag::kAllocated;
+
+template <typename T, int N, typename A>
+inline void swap(InlinedVector<T, N, A>& a, InlinedVector<T, N, A>& b) {
+ a.swap(b);
+}
+
+template <typename T, int N, typename A>
+inline bool operator==(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
+}
+
+template <typename T, int N, typename A>
+inline bool operator!=(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return !(a == b);
+}
+
+template <typename T, int N, typename A>
+inline bool operator<(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
+}
+
+template <typename T, int N, typename A>
+inline bool operator>(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return b < a;
+}
+
+template <typename T, int N, typename A>
+inline bool operator<=(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return !(b < a);
+}
+
+template <typename T, int N, typename A>
+inline bool operator>=(const InlinedVector<T, N, A>& a,
+ const InlinedVector<T, N, A>& b) {
+ return !(a < b);
+}
+
+// ========================================
+// Implementation
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector()
+ : allocator_and_tag_(allocator_type()) {}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(const allocator_type& alloc)
+ : allocator_and_tag_(alloc) {}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(size_t n)
+ : allocator_and_tag_(allocator_type()) {
+ InitAssign(n);
+}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(size_t n, const value_type& elem,
+ const allocator_type& alloc)
+ : allocator_and_tag_(alloc) {
+ InitAssign(n, elem);
+}
+
+template <typename T, int N, typename A>
+inline InlinedVector<T, N, A>::InlinedVector(const InlinedVector& v)
+ : allocator_and_tag_(v.allocator()) {
+ reserve(v.size());
+ if (allocated()) {
+ UninitializedCopyAllocated(v.begin(), v.end(), allocated_space());
+ } else {
+ UninitializedCopyInlined(v.begin(), v.end(), inlined_space());
+ }
+ set_size_internal(v.size());
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::InitAssign(size_t n, const value_type& t) {
+ if (n > static_cast<size_t>(N)) {
+ Allocation new_allocation(allocator(), n);
+ init_allocation(new_allocation);
+ set_allocated();
+ UninitializedFillAllocated(allocated_space(), allocated_space() + n, t);
+ } else {
+ UninitializedFillInlined(inlined_space(), inlined_space() + n, t);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::InitAssign(size_t n) {
+ if (n > static_cast<size_t>(N)) {
+ Allocation new_allocation(allocator(), n);
+ init_allocation(new_allocation);
+ set_allocated();
+ UninitializedFillAllocated(allocated_space(), allocated_space() + n);
+ } else {
+ UninitializedFillInlined(inlined_space(), inlined_space() + n);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::resize(size_t n) {
+ size_t s = size();
+ if (n < s) {
+ erase(begin() + n, end());
+ return;
+ }
+ reserve(n);
+ DCHECK_GE(capacity(), n);
+
+ // Fill new space with elements constructed in-place.
+ if (allocated()) {
+ UninitializedFillAllocated(allocated_space() + s, allocated_space() + n);
+ } else {
+ UninitializedFillInlined(inlined_space() + s, inlined_space() + n);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::resize(size_t n, const value_type& elem) {
+ size_t s = size();
+ if (n < s) {
+ erase(begin() + n, end());
+ return;
+ }
+ reserve(n);
+ DCHECK_GE(capacity(), n);
+
+ // Fill new space with copies of 'elem'.
+ if (allocated()) {
+ UninitializedFillAllocated(allocated_space() + s, allocated_space() + n,
+ elem);
+ } else {
+ UninitializedFillInlined(inlined_space() + s, inlined_space() + n, elem);
+ }
+ set_size_internal(n);
+}
+
+template <typename T, int N, typename A>
+typename InlinedVector<T, N, A>::iterator InlinedVector<T, N, A>::insert(
+ iterator pos, const value_type& v) {
+ DCHECK_GE(pos, begin());
+ DCHECK_LE(pos, end());
+ if (pos == end()) {
+ push_back(v);
+ return end() - 1;
+ }
+ size_t s = size();
+ size_t idx = std::distance(begin(), pos);
+ if (s == capacity()) {
+ EnlargeBy(1);
+ }
+ CHECK_LT(s, capacity());
+ pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator.
+
+ if (allocated()) {
+ ConstructAllocated(allocated_space() + s, *(allocated_space() + s - 1));
+ std::copy_backward(pos, allocated_space() + s - 1, allocated_space() + s);
+ } else {
+ ConstructInlined(inlined_space() + s, *(inlined_space() + s - 1));
+ std::copy_backward(pos, inlined_space() + s - 1, inlined_space() + s);
+ }
+
+ *pos = v;
+
+ set_size_internal(s + 1);
+ return pos;
+}
+
+template <typename T, int N, typename A>
+typename InlinedVector<T, N, A>::iterator InlinedVector<T, N, A>::erase(
+ iterator first, iterator last) {
+ DCHECK_LE(begin(), first);
+ DCHECK_LE(first, last);
+ DCHECK_LE(last, end());
+
+ size_t s = size();
+ ptrdiff_t erase_gap = std::distance(first, last);
+
+ if (allocated()) {
+ std::copy(last, allocated_space() + s, first);
+ DestroyAllocated(allocated_space() + s - erase_gap, allocated_space() + s);
+ } else {
+ std::copy(last, inlined_space() + s, first);
+ DestroyInlined(inlined_space() + s - erase_gap, inlined_space() + s);
+ }
+
+ set_size_internal(size() - erase_gap);
+
+ return first;
+}
+
+template <typename T, int N, typename A>
+void InlinedVector<T, N, A>::swap(InlinedVector& other) {
+ using std::swap; // Augment ADL with std::swap.
+ if (&other == this) {
+ return;
+ }
+ if (allocated() && other.allocated()) {
+ // Both out of line, so just swap the tag, allocation, and allocator.
+ swap(tag(), other.tag());
+ swap(allocation(), other.allocation());
+ swap(allocator(), other.allocator());
+ return;
+ }
+ if (!allocated() && !other.allocated()) {
+ // Both inlined: swap up to smaller size, then move remaining elements.
+ InlinedVector* a = this;
+ InlinedVector* b = &other;
+ if (size() < other.size()) {
+ swap(a, b);
+ }
+
+ const size_t a_size = a->size();
+ const size_t b_size = b->size();
+ DCHECK_GE(a_size, b_size);
+ // 'a' is larger. Swap the elements up to the smaller array size.
+ std::swap_ranges(a->inlined_space(), a->inlined_space() + b_size,
+ b->inlined_space());
+
+ // Move the remaining elements: A[b_size,a_size) -> B[b_size,a_size)
+ b->UninitializedCopyInlined(a->inlined_space() + b_size,
+ a->inlined_space() + a_size,
+ b->inlined_space() + b_size);
+ a->DestroyInlined(a->inlined_space() + b_size, a->inlined_space() + a_size);
+
+ swap(a->tag(), b->tag());
+ swap(a->allocator(), b->allocator());
+ DCHECK_EQ(b->size(), a_size);
+ DCHECK_EQ(a->size(), b_size);
+ return;
+ }
+ // One is out of line, one is inline.
+ // We first move the elements from the inlined vector into the
+ // inlined space in the other vector. We then put the other vector's
+ // pointer/capacity into the originally inlined vector and swap
+ // the tags.
+ InlinedVector* a = this;
+ InlinedVector* b = &other;
+ if (a->allocated()) {
+ swap(a, b);
+ }
+ DCHECK(!a->allocated());
+ DCHECK(b->allocated());
+ const size_t a_size = a->size();
+ const size_t b_size = b->size();
+
+  // We made local copies of size(); the tags no longer need to stay accurate.
+ swap(a->tag(), b->tag());
+
+  // Copy b_allocation out before b's union gets clobbered by inlined_space.
+ Allocation b_allocation = b->allocation();
+
+ b->UninitializedCopyInlined(a->inlined_space(), a->inlined_space() + a_size,
+ b->inlined_space());
+ a->DestroyInlined(a->inlined_space(), a->inlined_space() + a_size);
+
+ a->allocation() = b_allocation;
+
+ if (a->allocator() != b->allocator()) {
+ swap(a->allocator(), b->allocator());
+ }
+
+ DCHECK_EQ(b->size(), a_size);
+ DCHECK_EQ(a->size(), b_size);
+}
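
The branch just above handles the mixed case: the inline elements are moved into the other object's inline buffer, the heap allocation is handed over, and the tags are exchanged. A small sketch of what that looks like from the caller's side (illustrative only, not part of the original change; the variable names are invented for the example):

    // Illustrative only: swapping an inline vector with a heap-backed one.
    tensorflow::gtl::InlinedVector<int, 4> small;  // fits in the inline buffer
    tensorflow::gtl::InlinedVector<int, 4> big;    // will spill to the heap
    for (int i = 0; i < 2; ++i) small.push_back(i);
    for (int i = 0; i < 10; ++i) big.push_back(100 + i);

    small.swap(big);
    // 'small' now owns the heap allocation holding the ten elements, while
    // 'big' holds the two elements in its inline buffer; sizes are exchanged.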
+
+template <typename T, int N, typename A>
+void InlinedVector<T, N, A>::EnlargeBy(size_t delta) {
+ const size_t s = size();
+ DCHECK_LE(s, capacity());
+
+ size_t target = std::max(static_cast<size_t>(N), s + delta);
+
+ // Compute new capacity by repeatedly doubling current capacity
+ // TODO(psrc): Check and avoid overflow?
+ size_t new_capacity = capacity();
+ while (new_capacity < target) {
+ new_capacity <<= 1;
+ }
+
+ Allocation new_allocation(allocator(), new_capacity);
+ new_allocation.set_size(s);
+
+ UninitializedCopyAllocated(array(), array() + s, new_allocation.buffer());
+
+ ResetAllocation(new_allocation);
+}
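
EnlargeBy() above, like GrowAndPushBack() before it, grows the heap capacity geometrically by repeated doubling until the target fits. A minimal sketch of the observable effect, assuming the doubling policy shown above (illustrative only, not part of the change):

    // Illustrative only: InlinedVector<int, 4> starts with capacity() == 4.
    tensorflow::gtl::InlinedVector<int, 4> v;
    for (int i = 0; i < 5; ++i) v.push_back(i);
    // Pushing the fifth element triggers GrowAndPushBack(), which allocates
    // 2 * capacity() == 8 slots on the heap; further growth doubles again
    // (8 -> 16 -> 32 ...), so push_back stays amortized O(1).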
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::DestroyInlined(value_type* ptr,
+ value_type* ptr_last) {
+ for (value_type* p = ptr; p != ptr_last; ++p) {
+ p->~value_type();
+ }
+
+// Overwrite unused memory with 0xab so we can catch uninitialized usage.
+// Cast to void* to tell the compiler that we don't care that we might be
+// scribbling on a vtable pointer.
+#ifndef NDEBUG
+ if (ptr != ptr_last) {
+ memset(reinterpret_cast<void*>(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr));
+ }
+#endif
+}
+
+template <typename T, int N, typename A>
+inline void InlinedVector<T, N, A>::DestroyAllocated(value_type* ptr,
+ value_type* ptr_last) {
+ for (value_type* p = ptr; p != ptr_last; ++p) {
+ AllocatorTraits::destroy(allocator(), p);
+ }
+
+// Overwrite unused memory with 0xab so we can catch uninitialized usage.
+// Cast to void* to tell the compiler that we don't care that we might be
+// scribbling on a vtable pointer.
+#ifndef NDEBUG
+ if (ptr != ptr_last) {
+ memset(reinterpret_cast<void*>(ptr), 0xab, sizeof(*ptr) * (ptr_last - ptr));
+ }
+#endif
+}
+
+template <typename T, int N, typename A>
+template <typename Iter>
+inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last,
+ std::input_iterator_tag) {
+ std::copy(first, last, std::back_inserter(*this));
+}
+
+template <typename T, int N, typename A>
+template <typename Iter>
+inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last,
+ std::forward_iterator_tag) {
+ typedef typename std::iterator_traits<Iter>::difference_type Length;
+ Length length = std::distance(first, last);
+ reserve(size() + length);
+ if (allocated()) {
+ UninitializedCopyAllocated(first, last, allocated_space() + size());
+ } else {
+ UninitializedCopyInlined(first, last, inlined_space() + size());
+ }
+ set_size_internal(size() + length);
+}
+
+template <typename T, int N, typename A>
+template <typename Iter>
+inline void InlinedVector<T, N, A>::AppendRange(Iter first, Iter last) {
+ typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
+ AppendRange(first, last, IterTag());
+}
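
AppendRange() dispatches on the iterator category: forward (and stronger) iterators let the range length be measured and reserved up front, while single-pass input iterators fall back to element-by-element push_back. A small sketch of how the two paths are reached through the range constructor exercised by the tests below (illustrative only; assumes <list> is included):

    // std::list iterators satisfy forward_iterator_tag, so this construction
    // measures the range, reserves once, then bulk-copies the elements.
    std::list<int> src = {1, 2, 3};
    tensorflow::gtl::InlinedVector<int, 2> v(src.begin(), src.end());
    // A single-pass source such as std::istream_iterator would instead select
    // the input_iterator_tag overload and append through std::back_inserter.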
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
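
The non-member swap() and comparison operators defined near the end of the header give InlinedVector full value semantics in generic code. A brief sketch (illustrative only, not part of the change; Demo is an invented name):

    #include "tensorflow/core/lib/gtl/inlined_vector.h"

    void Demo() {
      tensorflow::gtl::InlinedVector<int, 8> a, b;
      a.push_back(1); a.push_back(2);
      b.push_back(1); b.push_back(3);
      bool less = (a < b);   // lexicographic: 2 < 3 at the first mismatch
      bool same = (a == b);  // false: equal sizes, different second element
      using std::swap;
      swap(a, b);            // picks up the gtl::swap overload via ADL
      (void)less; (void)same;
    }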
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
new file mode 100644
index 0000000000..ec5fe1eaa8
--- /dev/null
+++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc
@@ -0,0 +1,905 @@
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+#include <list>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
+
+// A type that counts the number of live instances of the type.
+static int64 instances = 0;
+class Instance {
+ public:
+ int value_;
+ explicit Instance(int x) : value_(x) { instances++; }
+ Instance(const Instance& x) : value_(x.value_) { instances++; }
+ ~Instance() { instances--; }
+
+ friend inline void swap(Instance& a, Instance& b) {
+ using std::swap;
+ swap(a.value_, b.value_);
+ }
+
+ friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
+ return o << "[value:" << v.value_ << "]";
+ }
+};
+
+typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
+
+// A simple reference counted class to make sure that the proper elements are
+// destroyed in the erase(begin, end) test.
+class RefCounted {
+ public:
+ RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
+
+ RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
+ VLOG(5) << "[RefCounted: copy"
+ << " from count @" << v.count_ << "]";
+ Ref();
+ }
+
+ ~RefCounted() {
+ Unref();
+ count_ = NULL;
+ }
+
+ friend void swap(RefCounted& a, RefCounted& b) {
+ using std::swap;
+ swap(a.value_, b.value_);
+ swap(a.count_, b.count_);
+ }
+
+ RefCounted& operator=(RefCounted v) {
+ using std::swap;
+ swap(*this, v);
+ return *this;
+ }
+
+ void Ref() const {
+ CHECK(count_ != NULL);
+ ++(*count_);
+ VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
+ }
+
+ void Unref() const {
+ --(*count_);
+ CHECK_GE(*count_, 0);
+ VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
+ }
+
+ int count() const { return *count_; }
+
+ friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
+ return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
+ }
+
+ int value_;
+ int* count_;
+};
+
+typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
+
+// A class with a vtable pointer
+class Dynamic {
+ public:
+ virtual ~Dynamic() {}
+
+ friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
+ return o << "[Dynamic]";
+ }
+};
+
+typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
+
+// Append 0..len-1 to *v
+static void Fill(IntVec* v, int len, int offset = 0) {
+ for (int i = 0; i < len; i++) {
+ v->push_back(i + offset);
+ }
+}
+
+static IntVec Fill(int len, int offset = 0) {
+ IntVec v;
+ Fill(&v, len, offset);
+ return v;
+}
+
+TEST(IntVec, SimpleOps) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v;
+ const IntVec& cv = v; // const alias
+
+ Fill(&v, len);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(i, v[i]);
+ }
+ EXPECT_EQ(v.begin(), v.array());
+ EXPECT_EQ(v.begin(), v.mutable_array());
+
+ EXPECT_EQ(v.begin(), v.data());
+ EXPECT_EQ(cv.begin(), cv.data());
+
+ int counter = 0;
+ for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
+ EXPECT_EQ(counter, *iter);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ counter = 0;
+ for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
+ EXPECT_EQ(counter, *iter);
+ counter++;
+ }
+ EXPECT_EQ(counter, len);
+
+ if (len > 0) {
+ EXPECT_EQ(0, v.front());
+ EXPECT_EQ(len - 1, v.back());
+ v.pop_back();
+ EXPECT_EQ(len - 1, v.size());
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(i, v[i]);
+ }
+ }
+ }
+}
+
+TEST(IntVec, Erase) {
+ for (int len = 1; len < 20; len++) {
+ for (int i = 0; i < len; ++i) {
+ IntVec v;
+ Fill(&v, len);
+ v.erase(v.begin() + i);
+ EXPECT_EQ(len - 1, v.size());
+ for (int j = 0; j < i; ++j) {
+ EXPECT_EQ(j, v[j]);
+ }
+ for (int j = i; j < len - 1; ++j) {
+ EXPECT_EQ(j + 1, v[j]);
+ }
+ }
+ }
+}
+
+// At the end of this test loop, the elements between [erase_begin, erase_end)
+// should have reference counts == 0, and all other elements should have
+// reference counts == 1.
+TEST(RefCountedVec, EraseBeginEnd) {
+ for (int len = 1; len < 20; ++len) {
+ for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
+ for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
+ std::vector<int> counts(len, 0);
+ RefCountedVec v;
+ for (int i = 0; i < len; ++i) {
+ v.push_back(RefCounted(i, &counts[i]));
+ }
+
+ int erase_len = erase_end - erase_begin;
+
+ v.erase(v.begin() + erase_begin, v.begin() + erase_end);
+
+ EXPECT_EQ(len - erase_len, v.size());
+
+ // Check the elements before the first element erased.
+ for (int i = 0; i < erase_begin; ++i) {
+ EXPECT_EQ(i, v[i].value_);
+ }
+
+        // Check the elements after the erased range.
+ for (size_t i = erase_begin; i < v.size(); ++i) {
+ EXPECT_EQ(i + erase_len, v[i].value_);
+ }
+
+ // Check that the elements at the beginning are preserved.
+ for (int i = 0; i < erase_begin; ++i) {
+ EXPECT_EQ(1, counts[i]);
+ }
+
+ // Check that the erased elements are destroyed
+ for (int i = erase_begin; i < erase_end; ++i) {
+ EXPECT_EQ(0, counts[i]);
+ }
+
+ // Check that the elements at the end are preserved.
+ for (int i = erase_end; i < len; ++i) {
+ EXPECT_EQ(1, counts[i]);
+ }
+ }
+ }
+ }
+}
+
+struct NoDefaultCtor {
+ explicit NoDefaultCtor(int /* x */) {}
+};
+struct NoCopy {
+ NoCopy() {}
+ NoCopy(const NoCopy& /* x */) = delete;
+};
+struct NoAssign {
+ NoAssign() {}
+ NoAssign& operator=(const NoAssign& /* x */) = delete;
+};
+TEST(InlinedVectorTest, NoDefaultCtor) {
+ tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
+ (void)v;
+}
+TEST(InlinedVectorTest, NoCopy) {
+ tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
+ (void)v;
+}
+TEST(InlinedVectorTest, NoAssign) {
+ tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
+ (void)v;
+}
+
+TEST(IntVec, Insert) {
+ for (int len = 0; len < 20; len++) {
+ for (int pos = 0; pos <= len; pos++) {
+ IntVec v;
+ Fill(&v, len);
+ v.insert(v.begin() + pos, 9999);
+ EXPECT_EQ(v.size(), len + 1);
+ for (int i = 0; i < pos; i++) {
+ EXPECT_EQ(v[i], i);
+ }
+ EXPECT_EQ(v[pos], 9999);
+ for (size_t i = pos + 1; i < v.size(); i++) {
+ EXPECT_EQ(v[i], i - 1);
+ }
+ }
+ }
+}
+
+TEST(RefCountedVec, InsertConstructorDestructor) {
+ // Make sure the proper construction/destruction happen during insert
+ // operations.
+ for (int len = 0; len < 20; len++) {
+ SCOPED_TRACE(len);
+ for (int pos = 0; pos <= len; pos++) {
+ SCOPED_TRACE(pos);
+ std::vector<int> counts(len, 0);
+ RefCountedVec v;
+ for (int i = 0; i < len; ++i) {
+ SCOPED_TRACE(i);
+ v.push_back(RefCounted(i, &counts[i]));
+ }
+
+ for (auto elem : counts) {
+ EXPECT_EQ(1, elem);
+ }
+
+ int inserted_count = 0;
+ RefCounted insert_element(9999, &inserted_count);
+ EXPECT_EQ(1, inserted_count);
+ v.insert(v.begin() + pos, insert_element);
+ EXPECT_EQ(2, inserted_count);
+      // Check that each original element still has exactly one live copy.
+ for (auto elem : counts) {
+ EXPECT_EQ(1, elem);
+ }
+ EXPECT_EQ(2, inserted_count);
+ }
+ }
+}
+
+TEST(IntVec, Resize) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v;
+ Fill(&v, len);
+
+ // Try resizing up and down by k elements
+ static const int kResizeElem = 1000000;
+ for (int k = 0; k < 10; k++) {
+ // Enlarging resize
+ v.resize(len + k, kResizeElem);
+ EXPECT_EQ(len + k, v.size());
+ EXPECT_LE(len + k, v.capacity());
+ for (int i = 0; i < len + k; i++) {
+ if (i < len) {
+ EXPECT_EQ(i, v[i]);
+ } else {
+ EXPECT_EQ(kResizeElem, v[i]);
+ }
+ }
+
+ // Shrinking resize
+ v.resize(len, kResizeElem);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(i, v[i]);
+ }
+ }
+ }
+}
+
+TEST(IntVec, InitWithLength) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v(len, 7);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+ for (int i = 0; i < len; i++) {
+ EXPECT_EQ(7, v[i]);
+ }
+ }
+}
+
+TEST(IntVec, CopyConstructorAndAssignment) {
+ for (int len = 0; len < 20; len++) {
+ IntVec v;
+ Fill(&v, len);
+ EXPECT_EQ(len, v.size());
+ EXPECT_LE(len, v.capacity());
+
+ IntVec v2(v);
+ EXPECT_EQ(v, v2);
+
+ for (int start_len = 0; start_len < 20; start_len++) {
+ IntVec v3;
+ Fill(&v3, start_len, 99); // Add dummy elements that should go away
+ v3 = v;
+ EXPECT_EQ(v, v3);
+ }
+ }
+}
+
+TEST(OverheadTest, Storage) {
+ // Check for size overhead.
+ // In particular, ensure that std::allocator doesn't cost anything to store.
+ // The union should be absorbing some of the allocation bookkeeping overhead
+ // in the larger vectors, leaving only the size_ field as overhead.
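+  // For example, with pointer-sized size_t an InlinedVector<int*, 1> is one
+  // word of tag plus a union just large enough for the three-word Allocation
+  // (size, capacity, buffer pointer): four words in total, i.e. the three
+  // words of overhead checked by the first expectation below.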
+ using tensorflow::gtl::InlinedVector;
+ EXPECT_EQ(3 * sizeof(int*),
+ sizeof(InlinedVector<int*, 1>) - 1 * sizeof(int*));
+ EXPECT_EQ(2 * sizeof(int*),
+ sizeof(InlinedVector<int*, 2>) - 2 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 3>) - 3 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 4>) - 4 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 5>) - 5 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 6>) - 6 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 7>) - 7 * sizeof(int*));
+ EXPECT_EQ(1 * sizeof(int*),
+ sizeof(InlinedVector<int*, 8>) - 8 * sizeof(int*));
+}
+
+TEST(IntVec, Clear) {
+ for (int len = 0; len < 20; len++) {
+ SCOPED_TRACE(len);
+ IntVec v;
+ Fill(&v, len);
+ v.clear();
+ EXPECT_EQ(0, v.size());
+ EXPECT_EQ(v.begin(), v.end());
+ }
+}
+
+TEST(IntVec, Reserve) {
+ for (size_t len = 0; len < 20; len++) {
+ IntVec v;
+ Fill(&v, len);
+
+ for (size_t newlen = 0; newlen < 100; newlen++) {
+ const int* start_rep = v.array();
+ v.reserve(newlen);
+ const int* final_rep = v.array();
+ if (newlen <= len) {
+ EXPECT_EQ(start_rep, final_rep);
+ }
+ EXPECT_LE(newlen, v.capacity());
+
+ // Filling up to newlen should not change rep
+ while (v.size() < newlen) {
+ v.push_back(0);
+ }
+ EXPECT_EQ(final_rep, v.array());
+ }
+ }
+}
+
+template <typename T>
+static std::vector<typename T::value_type> Vec(const T& src) {
+ std::vector<typename T::value_type> result;
+ for (const auto& elem : src) {
+ result.push_back(elem);
+ }
+ return result;
+}
+
+TEST(IntVec, SelfRefPushBack) {
+ std::vector<string> std_v;
+ tensorflow::gtl::InlinedVector<string, 4> v;
+ const string s = "A very long string to ensure heap.";
+ std_v.push_back(s);
+ v.push_back(s);
+ for (int i = 0; i < 20; ++i) {
+ EXPECT_EQ(std_v, Vec(v));
+
+ v.push_back(v.back());
+ std_v.push_back(std_v.back());
+ }
+ EXPECT_EQ(std_v, Vec(v));
+}
+
+TEST(IntVec, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ SCOPED_TRACE(l1);
+ for (int l2 = 0; l2 < 20; l2++) {
+ SCOPED_TRACE(l2);
+ IntVec a = Fill(l1, 0);
+ IntVec b = Fill(l2, 100);
+ {
+ using std::swap;
+ swap(a, b);
+ }
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ SCOPED_TRACE(i);
+ EXPECT_EQ(i, b[i]);
+ }
+ for (int i = 0; i < l2; i++) {
+ SCOPED_TRACE(i);
+ EXPECT_EQ(100 + i, a[i]);
+ }
+ }
+ }
+}
+
+TEST(InstanceVec, Swap) {
+ for (int l1 = 0; l1 < 20; l1++) {
+ for (int l2 = 0; l2 < 20; l2++) {
+ InstanceVec a, b;
+ for (int i = 0; i < l1; i++) a.push_back(Instance(i));
+ for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
+ EXPECT_EQ(l1 + l2, instances);
+ {
+ using std::swap;
+ swap(a, b);
+ }
+ EXPECT_EQ(l1 + l2, instances);
+ EXPECT_EQ(l1, b.size());
+ EXPECT_EQ(l2, a.size());
+ for (int i = 0; i < l1; i++) {
+ EXPECT_EQ(i, b[i].value_);
+ }
+ for (int i = 0; i < l2; i++) {
+ EXPECT_EQ(100 + i, a[i].value_);
+ }
+ }
+ }
+}
+
+TEST(IntVec, EqualAndNotEqual) {
+ IntVec a, b;
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+
+ a.push_back(3);
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ b.push_back(3);
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+
+ b.push_back(7);
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ a.push_back(6);
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ a.clear();
+ b.clear();
+ for (int i = 0; i < 100; i++) {
+ a.push_back(i);
+ b.push_back(i);
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+
+ b[i] = b[i] + 1;
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a != b);
+
+ b[i] = b[i] - 1; // Back to before
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a != b);
+ }
+}
+
+TEST(IntVec, RelationalOps) {
+ IntVec a, b;
+ EXPECT_FALSE(a < b);
+ EXPECT_FALSE(b < a);
+ EXPECT_FALSE(a > b);
+ EXPECT_FALSE(b > a);
+ EXPECT_TRUE(a <= b);
+ EXPECT_TRUE(b <= a);
+ EXPECT_TRUE(a >= b);
+ EXPECT_TRUE(b >= a);
+ b.push_back(3);
+ EXPECT_TRUE(a < b);
+ EXPECT_FALSE(b < a);
+ EXPECT_FALSE(a > b);
+ EXPECT_TRUE(b > a);
+ EXPECT_TRUE(a <= b);
+ EXPECT_FALSE(b <= a);
+ EXPECT_FALSE(a >= b);
+ EXPECT_TRUE(b >= a);
+}
+
+TEST(InstanceVec, CountConstructorsDestructors) {
+ const int start = instances;
+ for (int len = 0; len < 20; len++) {
+ InstanceVec v;
+ for (int i = 0; i < len; i++) {
+ v.push_back(Instance(i));
+ }
+ EXPECT_EQ(start + len, instances);
+
+ { // Copy constructor should create 'len' more instances.
+ InstanceVec v_copy(v);
+ EXPECT_EQ(start + len + len, instances);
+ }
+ EXPECT_EQ(start + len, instances);
+
+ // Enlarging resize() must construct some objects
+ v.resize(len + 10, Instance(100));
+ EXPECT_EQ(start + len + 10, instances);
+
+ // Shrinking resize() must destroy some objects
+ v.resize(len, Instance(100));
+ EXPECT_EQ(start + len, instances);
+
+ // reserve() must not increase the number of initialized objects
+ v.reserve(len + 1000);
+ EXPECT_EQ(start + len, instances);
+
+ // pop_back() and erase() must destroy one object
+ if (len > 0) {
+ v.pop_back();
+ EXPECT_EQ(start + len - 1, instances);
+ if (!v.empty()) {
+ v.erase(v.begin());
+ EXPECT_EQ(start + len - 2, instances);
+ }
+ }
+ }
+ EXPECT_EQ(start, instances);
+}
+
+TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
+ const int start = instances;
+ for (int len = 0; len < 20; len++) {
+ for (int longorshort = 0; longorshort <= 1; ++longorshort) {
+ InstanceVec longer, shorter;
+ for (int i = 0; i < len; i++) {
+ longer.push_back(Instance(i));
+ shorter.push_back(Instance(i));
+ }
+ longer.push_back(Instance(len));
+ EXPECT_EQ(start + len + len + 1, instances);
+
+ if (longorshort) {
+ shorter = longer;
+ EXPECT_EQ(start + (len + 1) + (len + 1), instances);
+ } else {
+ longer = shorter;
+ EXPECT_EQ(start + len + len, instances);
+ }
+ }
+ }
+ EXPECT_EQ(start, instances);
+}
+
+TEST(RangedConstructor, SimpleType) {
+ std::vector<int> source_v = {4, 5, 6};
+ // First try to fit in inline backing
+ tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
+ EXPECT_EQ(3, v.size());
+ EXPECT_EQ(4, v.capacity()); // Indication that we're still on inlined storage
+ EXPECT_EQ(4, v[0]);
+ EXPECT_EQ(5, v[1]);
+ EXPECT_EQ(6, v[2]);
+
+ // Now, force a re-allocate
+ tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
+ source_v.end());
+ EXPECT_EQ(3, realloc_v.size());
+ EXPECT_LT(2, realloc_v.capacity());
+ EXPECT_EQ(4, realloc_v[0]);
+ EXPECT_EQ(5, realloc_v[1]);
+ EXPECT_EQ(6, realloc_v[2]);
+}
+
+TEST(RangedConstructor, ComplexType) {
+ // We also use a list here to pass a different flavor of iterator (e.g. not
+ // random-access).
+ std::list<Instance> source_v = {Instance(0)};
+
+ // First try to fit in inline backing
+ tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
+ source_v.end());
+ EXPECT_EQ(1, v.size());
+ EXPECT_EQ(1, v.capacity()); // Indication that we're still on inlined storage
+ EXPECT_EQ(0, v[0].value_);
+
+ std::list<Instance> source_v2 = {Instance(0), Instance(1)};
+ // Now, force a re-allocate
+ tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
+ source_v2.end());
+ EXPECT_EQ(2, realloc_v.size());
+ EXPECT_LT(1, realloc_v.capacity());
+ EXPECT_EQ(0, realloc_v[0].value_);
+ EXPECT_EQ(1, realloc_v[1].value_);
+}
+
+TEST(RangedConstructor, ElementsAreConstructed) {
+ std::vector<string> source_v = {"cat", "dog"};
+
+  // Force expansion and re-allocation of v. Ensures that new elements are
+  // constructed when the vector is expanded.
+ tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
+ EXPECT_EQ("cat", v[0]);
+ EXPECT_EQ("dog", v[1]);
+}
+
+TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
+ auto vec = tensorflow::gtl::InlinedVector<int, 4>{4, 5, 6};
+ EXPECT_EQ(3, vec.size());
+ EXPECT_EQ(4, vec.capacity());
+ EXPECT_EQ(4, vec[0]);
+ EXPECT_EQ(5, vec[1]);
+ EXPECT_EQ(6, vec[2]);
+}
+
+TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
+ auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
+ EXPECT_EQ(3, vec.size());
+ EXPECT_LE(3, vec.capacity());
+ EXPECT_EQ(4, vec[0]);
+ EXPECT_EQ(5, vec[1]);
+ EXPECT_EQ(6, vec[2]);
+}
+
+TEST(InitializerListConstructor, DisparateTypesInList) {
+ EXPECT_EQ((std::vector<int>{-7, 8}),
+ Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
+
+ EXPECT_EQ(
+ (std::vector<string>{"foo", "bar"}),
+ Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
+}
+
+TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
+ auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
+ EXPECT_EQ(1, vec.size());
+ EXPECT_EQ(1, vec.capacity());
+ EXPECT_EQ(0, vec[0].value_);
+}
+
+TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
+ auto vec =
+ tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
+ EXPECT_EQ(2, vec.size());
+ EXPECT_LE(2, vec.capacity());
+ EXPECT_EQ(0, vec[0].value_);
+ EXPECT_EQ(1, vec[1].value_);
+}
+
+TEST(DynamicVec, DynamicVecCompiles) {
+ DynamicVec v;
+ (void)v;
+}
+
+#ifdef INLINED_VECTOR_HAS_ALLOC
+TEST(AllocatorSupportTest, Constructors) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7};
+ int64 allocated = 0;
+ MyAlloc alloc(&allocated);
+ { AllocVec TF_ATTRIBUTE_UNUSED v; }
+ { AllocVec TF_ATTRIBUTE_UNUSED v(alloc); }
+ { AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc); }
+#ifdef LANG_CXX11
+ { AllocVec TF_ATTRIBUTE_UNUSED v({1, 2, 3}, alloc); }
+#endif // LANG_CXX11
+}
+
+TEST(AllocatorSupportTest, CountAllocations) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ const int ia[] = {0, 1, 2, 3, 4, 5, 6, 7};
+ int64 allocated = 0;
+ MyAlloc alloc(&allocated);
+ {
+ AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + 4, alloc);
+ EXPECT_THAT(allocated, 0);
+ }
+ EXPECT_THAT(allocated, 0);
+ {
+ AllocVec TF_ATTRIBUTE_UNUSED v(ia, ia + arraysize(ia), alloc);
+ EXPECT_THAT(allocated, v.size() * sizeof(int));
+ }
+ EXPECT_THAT(allocated, 0);
+}
+
+TEST(AllocatorSupportTest, SwapBothAllocated) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ int64 allocated1 = 0;
+ int64 allocated2 = 0;
+ {
+ const std::vector<int> ia1 = {0, 1, 2, 3, 4, 5, 6, 7};
+ const std::vector<int> ia2 = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ MyAlloc a1(&allocated1);
+ MyAlloc a2(&allocated2);
+ AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1);
+ AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2);
+ EXPECT_LT(v1.capacity(), v2.capacity());
+ EXPECT_THAT(allocated1, v1.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, v2.capacity() * sizeof(int));
+ v1.swap(v2);
+ EXPECT_EQ(ia2, Vec(v1));
+ EXPECT_EQ(ia1, Vec(v2));
+ EXPECT_THAT(allocated1, v2.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, v1.capacity() * sizeof(int));
+ }
+ EXPECT_THAT(allocated1, 0);
+ EXPECT_THAT(allocated2, 0);
+}
+
+TEST(AllocatorSupportTest, SwapOneAllocated) {
+ typedef STLCountingAllocator<int> MyAlloc;
+ typedef tensorflow::gtl::InlinedVector<int, 4, MyAlloc> AllocVec;
+ int64 allocated1 = 0;
+ int64 allocated2 = 0;
+ {
+ const std::vector<int> ia1 = {0, 1, 2, 3, 4, 5, 6, 7};
+ const std::vector<int> ia2 = {0, 1, 2, 3};
+ MyAlloc a1(&allocated1);
+ MyAlloc a2(&allocated2);
+ AllocVec v1(ia1.data(), ia1.data() + ia1.size(), a1);
+ AllocVec v2(ia2.data(), ia2.data() + ia2.size(), a2);
+ EXPECT_THAT(allocated1, v1.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, 0);
+ v1.swap(v2);
+ EXPECT_EQ(ia2, Vec(v1));
+ EXPECT_EQ(ia1, Vec(v2));
+ EXPECT_THAT(allocated1, v2.capacity() * sizeof(int));
+ EXPECT_THAT(allocated2, 0);
+ EXPECT_TRUE(v2.get_allocator() == a1);
+ EXPECT_TRUE(v1.get_allocator() == a2);
+ }
+ EXPECT_THAT(allocated1, 0);
+ EXPECT_THAT(allocated2, 0);
+}
+#endif // INLINED_VECTOR_HAS_ALLOC
+
+static void BM_InlinedVectorFill(int iters, int len) {
+ for (int i = 0; i < iters; i++) {
+ IntVec v;
+ for (int j = 0; j < len; j++) {
+ v.push_back(j);
+ }
+ }
+ testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int));
+}
+BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
+
+static void BM_InlinedVectorFillRange(int iters, int len) {
+ std::unique_ptr<int[]> ia(new int[len]);
+ for (int j = 0; j < len; j++) {
+ ia[j] = j;
+ }
+ for (int i = 0; i < iters; i++) {
+ IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
+ }
+ testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int));
+}
+BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
+
+static void BM_StdVectorFill(int iters, int len) {
+ for (int i = 0; i < iters; i++) {
+ std::vector<int> v;
+ for (int j = 0; j < len; j++) {
+ v.push_back(j);
+ }
+ }
+ testing::BytesProcessed((static_cast<int64>(iters) * len) * sizeof(int));
+}
+BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
+
+namespace {
+struct Buffer { // some arbitrary structure for benchmarking.
+ char* base;
+ int length;
+ int capacity;
+ void* user_data;
+};
+} // anonymous namespace
+
+static void BM_InlinedVectorTenAssignments(int iters, int len) {
+ typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
+
+ BufferVec src;
+ src.resize(len);
+
+ iters *= 10;
+ BufferVec dst;
+ for (int i = 0; i < iters; i++) {
+ dst = src;
+ }
+}
+BENCHMARK(BM_InlinedVectorTenAssignments)
+ ->Arg(0)
+ ->Arg(1)
+ ->Arg(2)
+ ->Arg(3)
+ ->Arg(4)
+ ->Arg(20);
+
+static void BM_CreateFromInitializerList(int iters) {
+ for (; iters > 0; iters--) {
+ tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
+ (void)x[0];
+ }
+}
+BENCHMARK(BM_CreateFromInitializerList);
+
+namespace {
+
+struct LargeSwappable {
+ LargeSwappable() : d_(1024, 17) {}
+ ~LargeSwappable() {}
+ LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
+
+ friend void swap(LargeSwappable& a, LargeSwappable& b) {
+ using std::swap;
+ swap(a.d_, b.d_);
+ }
+
+ LargeSwappable& operator=(LargeSwappable o) {
+ using std::swap;
+ swap(*this, o);
+ return *this;
+ }
+
+ std::vector<int> d_;
+};
+
+} // namespace
+
+static void BM_LargeSwappableElements(int iters, int len) {
+ typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
+ Vec a(len);
+ Vec b;
+ while (--iters >= 0) {
+ using std::swap;
+ swap(a, b);
+ }
+}
+BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h
new file mode 100644
index 0000000000..d3fcb08d38
--- /dev/null
+++ b/tensorflow/core/lib/gtl/int_type.h
@@ -0,0 +1,343 @@
+// #status: LEGACY
+// #category: Miscellaneous
+// #summary: Integral types; prefer util/intops/strong_int.h
+// #bugs: Infrastructure > C++ Library Team > util
+//
+// IntType is a simple template class mechanism for defining "logical"
+// integer-like class types that support many of the same functionalities
+// as native integer types, but which prevent assignment, construction, and
+// other operations from other similar integer-like types. Essentially, the
+// template class IntType<IntTypeName, ValueType> (where ValueType assumes
+// valid scalar types such as int, uint, int32, etc) has the additional
+// property that it cannot be assigned to or constructed from other IntTypes
+// or native integer types of equal or implicitly convertible type.
+//
+// The class is useful for preventing mingling of integer variables with
+// different logical roles or units. Unfortunately, C++ provides relatively
+// good type-safety for user-defined classes but not for integer types. It is
+// essentially up to the user to use nice variable names and comments to prevent
+// accidental mismatches, such as confusing a user-index with a group-index or a
+// time-in-milliseconds with a time-in-seconds. The use of typedefs is limited
+// in that regard, as typedefs do not enforce type-safety.
+//
+// USAGE -----------------------------------------------------------------------
+//
+// DEFINE_INT_TYPE(IntTypeName, ValueType);
+//
+// where:
+// IntTypeName: is the desired (unique) name for the "logical" integer type.
+// ValueType: is one of the integral types as defined by base::is_integral
+// (see base/type_traits.h).
+//
+// DISALLOWED OPERATIONS / TYPE-SAFETY ENFORCEMENT -----------------------------
+//
+// Consider these definitions and variable declarations:
+// DEFINE_INT_TYPE(GlobalDocID, int64);
+// DEFINE_INT_TYPE(LocalDocID, int64);
+// GlobalDocID global;
+// LocalDocID local;
+//
+// The class IntType prevents:
+//
+// 1) Assignments of other IntTypes with different IntTypeNames.
+//
+// global = local; <-- Fails to compile!
+// local = global; <-- Fails to compile!
+//
+// 2) Explicit/implicit conversion from an IntType to another IntType.
+//
+// LocalDocID l(global); <-- Fails to compile!
+// LocalDocID l = global; <-- Fails to compile!
+//
+// void GetGlobalDoc(GlobalDocID global) { }
+// GetGlobalDoc(global); <-- Compiles fine, types match!
+// GetGlobalDoc(local); <-- Fails to compile!
+//
+// 3) Implicit conversion from an IntType to a native integer type.
+//
+// void GetGlobalDoc(int64 global) { ...
+// GetGlobalDoc(global); <-- Fails to compile!
+// GetGlobalDoc(local); <-- Fails to compile!
+//
+// void GetLocalDoc(int32 local) { ...
+// GetLocalDoc(global); <-- Fails to compile!
+// GetLocalDoc(local); <-- Fails to compile!
+//
+//
+// SUPPORTED OPERATIONS --------------------------------------------------------
+//
+// The following operators are supported: unary: ++ and -- (both prefix and
+// postfix), +, -, ! (logical not), ~ (one's complement); comparison: ==, !=,
+// <, <=, >, >=; numerical: +, -, *, /, %, <<, >>; assignment: =, +=, -=, *=,
+// /=, %=, <<=, >>=; stream: <<. Each operator allows the same IntTypeName and
+// the ValueType to be used on both left- and right-hand sides.
+//
+// It also supports an accessor value() returning the stored value as ValueType,
+// and a templatized accessor value<T>() method that serves as syntactic sugar
+// for static_cast<T>(var.value()). These accessors are useful when assigning
+// the stored value into protocol buffer fields and using it as printf args.
+//
+// The class also defines a hash functor that allows the IntType to be used
+// as a key in hashed containers such as std::unordered_map and
+// std::unordered_set.
+//
+// We suggest using the IntTypeIndexedContainer wrapper around FixedArray and
+// STL vector (see int-type-indexed-container.h) if an IntType is intended to
+// be used as an index into these containers. These wrappers are indexed with
+// IntTypes, which enforces type-safe access.
+//
+// NB: this implementation does not attempt to abide by or enforce dimensional
+// analysis on these scalar types.
+//
+// EXAMPLES --------------------------------------------------------------------
+//
+// DEFINE_INT_TYPE(GlobalDocID, int64);
+// GlobalDocID global = 3;
+// cout << global; <-- Prints 3 to stdout.
+//
+// for (GlobalDocID i(0); i < global; ++i) {
+// cout << i;
+// }                                 <-- Prints 0 1 2 to stdout.
+//
+// DEFINE_INT_TYPE(LocalDocID, int64);
+// LocalDocID local;
+// cout << local;                    <-- Prints 0 to stdout, since the value
+//                                       is default-initialized to 0.
+//
+// local = 5;
+// local *= 2;
+// LocalDocID l(local);
+// cout << l + local; <-- Prints 20 to stdout.
+//
+// GenericSearchRequest request;
+// request.set_doc_id(global.value()); <-- Uses value() to extract the value
+// from the IntType class.
+//
+// REMARKS ---------------------------------------------------------------------
+//
+// The following bad usage is permissible although discouraged. Essentially, it
+// involves using the value*() accessors to extract the native integer type out
+// of the IntType class. Keep in mind that the primary reason for the IntType
+// class is to prevent *accidental* mingling of similar logical integer types --
+// and not type casting from one type to another.
+//
+// DEFINE_INT_TYPE(GlobalDocID, int64);
+// DEFINE_INT_TYPE(LocalDocID, int64);
+// GlobalDocID global;
+// LocalDocID local;
+//
+// global = local.value(); <-- Compiles fine.
+//
+// void GetGlobalDoc(GlobalDocID global) { ...
+// GetGlobalDoc(local.value()); <-- Compiles fine.
+//
+// void GetGlobalDoc(int64 global) { ...
+// GetGlobalDoc(local.value()); <-- Compiles fine.
+
+#ifndef TENSORFLOW_LIB_GTL_INT_TYPE_H_
+#define TENSORFLOW_LIB_GTL_INT_TYPE_H_
+
+#include <stddef.h>
+#include <functional>
+#include <iosfwd>
+#include <ostream> // NOLINT
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace gtl {
+
+template <typename IntTypeName, typename _ValueType>
+class IntType;
+
+// Defines the IntType using value_type and typedefs it to int_type_name.
+// The struct int_type_name ## _tag_ trickery is needed to ensure that a new
+// type is created per int_type_name.
+#define TF_LIB_GTL_DEFINE_INT_TYPE(int_type_name, value_type) \
+ struct int_type_name##_tag_ {}; \
+ typedef ::tensorflow::gtl::IntType<int_type_name##_tag_, value_type> \
+ int_type_name;
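
For orientation: the legacy comments above use the shorter internal name DEFINE_INT_TYPE; in this header the macro is TF_LIB_GTL_DEFINE_INT_TYPE. A minimal sketch of the intended type-safety (illustrative only; GlobalDocID/LocalDocID are the hypothetical example types from the comments above, and Sketch is an invented name):

    TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDocID, int64);
    TF_LIB_GTL_DEFINE_INT_TYPE(LocalDocID, int64);

    void Sketch() {
      GlobalDocID global(3);
      LocalDocID local(5);
      global += 1;                 // arithmetic with the raw ValueType is fine
      // global = local;           // would not compile: distinct logical types
      int64 raw = global.value();  // explicit escape hatch back to the integer
      (void)raw; (void)local;
    }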
+
+// Holds an integer value (of type ValueType) and behaves as a ValueType by
+// exposing assignment, unary, comparison, and arithmetic operators.
+//
+// The template parameter IntTypeName defines the name for the int type and must
+// be unique within a binary (the TF_LIB_GTL_DEFINE_INT_TYPE macro defined
+// above generates a unique IntTypeName). The parameter ValueType defines
+// the integer type value (see supported list above).
+//
+// This class is NOT thread-safe.
+template <typename IntTypeName, typename _ValueType>
+class IntType {
+ public:
+ typedef _ValueType ValueType; // for non-member operators
+ typedef IntType<IntTypeName, ValueType> ThisType; // Syntactic sugar.
+
+ // Note that this may change from time to time without notice.
+ struct Hasher {
+ size_t operator()(const IntType& arg) const {
+ return static_cast<size_t>(arg.value());
+ }
+ };
+
+ public:
+ // Default c'tor initializing value_ to 0.
+ constexpr IntType() : value_(0) {}
+ // C'tor explicitly initializing from a ValueType.
+ constexpr explicit IntType(ValueType value) : value_(value) {}
+
+ // IntType uses the default copy constructor, destructor and assign operator.
+ // The defaults are sufficient and omitting them allows the compiler to add
+ // the move constructor/assignment.
+
+ // -- ACCESSORS --------------------------------------------------------------
+ // The class provides a value() accessor returning the stored ValueType value_
+  // as well as a templatized accessor that is just syntactic sugar for
+ // static_cast<T>(var.value());
+ constexpr ValueType value() const { return value_; }
+
+ template <typename ValType>
+ constexpr ValType value() const {
+ return static_cast<ValType>(value_);
+ }
+
+ // -- UNARY OPERATORS --------------------------------------------------------
+ ThisType& operator++() { // prefix ++
+ ++value_;
+ return *this;
+ }
+ const ThisType operator++(int v) { // postfix ++
+ ThisType temp(*this);
+ ++value_;
+ return temp;
+ }
+ ThisType& operator--() { // prefix --
+ --value_;
+ return *this;
+ }
+ const ThisType operator--(int v) { // postfix --
+ ThisType temp(*this);
+ --value_;
+ return temp;
+ }
+
+ constexpr bool operator!() const { return value_ == 0; }
+ constexpr const ThisType operator+() const { return ThisType(value_); }
+ constexpr const ThisType operator-() const { return ThisType(-value_); }
+ constexpr const ThisType operator~() const { return ThisType(~value_); }
+
+// -- ASSIGNMENT OPERATORS ---------------------------------------------------
+// We support the following assignment operators: =, +=, -=, *=, /=, <<=, >>=
+// and %= for both ThisType and ValueType.
+#define INT_TYPE_ASSIGNMENT_OP(op) \
+ ThisType& operator op(const ThisType& arg_value) { \
+ value_ op arg_value.value(); \
+ return *this; \
+ } \
+ ThisType& operator op(ValueType arg_value) { \
+ value_ op arg_value; \
+ return *this; \
+ }
+ INT_TYPE_ASSIGNMENT_OP(+= );
+ INT_TYPE_ASSIGNMENT_OP(-= );
+ INT_TYPE_ASSIGNMENT_OP(*= );
+ INT_TYPE_ASSIGNMENT_OP(/= );
+ INT_TYPE_ASSIGNMENT_OP(<<= ); // NOLINT
+ INT_TYPE_ASSIGNMENT_OP(>>= ); // NOLINT
+ INT_TYPE_ASSIGNMENT_OP(%= );
+#undef INT_TYPE_ASSIGNMENT_OP
+
+ ThisType& operator=(ValueType arg_value) {
+ value_ = arg_value;
+ return *this;
+ }
+
+ private:
+ // The integer value of type ValueType.
+ ValueType value_;
+
+ static_assert(std::is_integral<ValueType>::value, "invalid integer type");
+} TF_PACKED;
+
+// -- NON-MEMBER STREAM OPERATORS ----------------------------------------------
+// We provide the << operator, primarily for logging purposes. Currently, there
+// seems to be no need for an >> operator.
+template <typename IntTypeName, typename ValueType>
+std::ostream& operator<<(std::ostream& os, // NOLINT
+ IntType<IntTypeName, ValueType> arg) {
+ return os << arg.value();
+}
+
+// -- NON-MEMBER ARITHMETIC OPERATORS ------------------------------------------
+// We support the +, -, *, /, %, <<, and >> operators with the same IntType
+// and ValueType types. The reason is to allow simple manipulation of these IDs
+// when used as indices in vectors and arrays.
+//
+// NB: Although it is possible to do IntType * IntType and IntType / IntType,
+// it is probably nonsensical from a dimensional-analysis perspective.
+#define INT_TYPE_ARITHMETIC_OP(op) \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr IntType<IntTypeName, ValueType> operator op( \
+ IntType<IntTypeName, ValueType> id_1, \
+ IntType<IntTypeName, ValueType> id_2) { \
+ return IntType<IntTypeName, ValueType>(id_1.value() op id_2.value()); \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr IntType<IntTypeName, ValueType> operator op( \
+ IntType<IntTypeName, ValueType> id, \
+ typename IntType<IntTypeName, ValueType>::ValueType arg_val) { \
+ return IntType<IntTypeName, ValueType>(id.value() op arg_val); \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr IntType<IntTypeName, ValueType> operator op( \
+ typename IntType<IntTypeName, ValueType>::ValueType arg_val, \
+ IntType<IntTypeName, ValueType> id) { \
+ return IntType<IntTypeName, ValueType>(arg_val op id.value()); \
+ }
+INT_TYPE_ARITHMETIC_OP(+);
+INT_TYPE_ARITHMETIC_OP(-);
+INT_TYPE_ARITHMETIC_OP(*);
+INT_TYPE_ARITHMETIC_OP(/ );
+INT_TYPE_ARITHMETIC_OP(<< ); // NOLINT
+INT_TYPE_ARITHMETIC_OP(>> ); // NOLINT
+INT_TYPE_ARITHMETIC_OP(% );
+#undef INT_TYPE_ARITHMETIC_OP
+
+// -- NON-MEMBER COMPARISON OPERATORS ------------------------------------------
+// Static inline comparison operators. We allow all comparison operators among
+// the following types (OP \in [==, !=, <, <=, >, >=]):
+// IntType<IntTypeName, ValueType> OP IntType<IntTypeName, ValueType>
+// IntType<IntTypeName, ValueType> OP ValueType
+// ValueType OP IntType<IntTypeName, ValueType>
+#define INT_TYPE_COMPARISON_OP(op) \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr bool operator op( \
+ IntType<IntTypeName, ValueType> id_1, \
+ IntType<IntTypeName, ValueType> id_2) { \
+ return id_1.value() op id_2.value(); \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr bool operator op( \
+ IntType<IntTypeName, ValueType> id, \
+ typename IntType<IntTypeName, ValueType>::ValueType val) { \
+ return id.value() op val; \
+ } \
+ template <typename IntTypeName, typename ValueType> \
+ static inline constexpr bool operator op( \
+ typename IntType<IntTypeName, ValueType>::ValueType val, \
+ IntType<IntTypeName, ValueType> id) { \
+ return val op id.value(); \
+ }
+INT_TYPE_COMPARISON_OP(== ); // NOLINT
+INT_TYPE_COMPARISON_OP(!= ); // NOLINT
+INT_TYPE_COMPARISON_OP(< ); // NOLINT
+INT_TYPE_COMPARISON_OP(<= ); // NOLINT
+INT_TYPE_COMPARISON_OP(> ); // NOLINT
+INT_TYPE_COMPARISON_OP(>= ); // NOLINT
+#undef INT_TYPE_COMPARISON_OP
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_INT_TYPE_H_
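
Since the header also ships the nested Hasher functor documented above, an IntType can serve directly as a hashed-container key: Hasher forwards to value(), and the non-member operator== supplies the key-equality check. A short sketch (illustrative only; GlobalDocID is again the hypothetical example type, and titles/AddTitle are invented names):

    #include <string>
    #include <unordered_map>
    #include "tensorflow/core/lib/gtl/int_type.h"

    TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDocID, int64);

    // No extra plumbing needed: Hasher hashes the stored value and the
    // non-member operator== handles key equality.
    std::unordered_map<GlobalDocID, std::string, GlobalDocID::Hasher> titles;

    void AddTitle() { titles[GlobalDocID(7)] = "intro"; }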
diff --git a/tensorflow/core/lib/gtl/int_type_test.cc b/tensorflow/core/lib/gtl/int_type_test.cc
new file mode 100644
index 0000000000..694886d345
--- /dev/null
+++ b/tensorflow/core/lib/gtl/int_type_test.cc
@@ -0,0 +1,282 @@
+// Unit test cases for IntType.
+
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TF_LIB_GTL_DEFINE_INT_TYPE(Int8_IT, int8);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt8_IT, uint8);
+TF_LIB_GTL_DEFINE_INT_TYPE(Int16_IT, int16);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt16_IT, uint16);
+TF_LIB_GTL_DEFINE_INT_TYPE(Int32_IT, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(Int64_IT, int64);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt32_IT, uint32);
+TF_LIB_GTL_DEFINE_INT_TYPE(UInt64_IT, uint64);
+TF_LIB_GTL_DEFINE_INT_TYPE(Long_IT, long); // NOLINT
+
+template <typename IntType_Type>
+class IntTypeTest : public ::testing::Test {
+ public:
+ typedef IntType_Type T;
+};
+
+// All tests below will be executed on all supported IntTypes.
+typedef ::testing::Types<Int8_IT, UInt8_IT, Int16_IT, UInt16_IT, Int32_IT,
+ Int64_IT, UInt64_IT, Long_IT> SupportedIntTypes;
+
+TYPED_TEST_CASE(IntTypeTest, SupportedIntTypes);
+
+TYPED_TEST(IntTypeTest, TestInitialization) {
+ constexpr typename TestFixture::T a;
+ constexpr typename TestFixture::T b(1);
+ constexpr typename TestFixture::T c(b);
+ EXPECT_EQ(0, a); // default initialization to 0
+ EXPECT_EQ(1, b);
+ EXPECT_EQ(1, c);
+}
+
+TYPED_TEST(IntTypeTest, TestOperators) {
+ typename TestFixture::T a(0);
+ typename TestFixture::T b(1);
+ typename TestFixture::T c(2);
+ constexpr typename TestFixture::T d(3);
+ constexpr typename TestFixture::T e(4);
+
+ // On all EXPECT_EQ below, we use the accessor value() so as not to invoke
+ // the comparison operators, which must themselves be tested.
+
+ // -- UNARY OPERATORS --------------------------------------------------------
+ EXPECT_EQ(0, (a++).value());
+ EXPECT_EQ(2, (++a).value());
+ EXPECT_EQ(2, (a--).value());
+ EXPECT_EQ(0, (--a).value());
+
+ EXPECT_EQ(true, !a);
+ EXPECT_EQ(false, !b);
+ static_assert(!d == false, "Unary operator! failed");
+
+ EXPECT_EQ(a.value(), +a);
+ static_assert(+d == d.value(), "Unary operator+ failed");
+ EXPECT_EQ(-a.value(), -a);
+ static_assert(-d == -d.value(), "Unary operator- failed");
+ EXPECT_EQ(~a.value(), ~a); // ~zero
+ EXPECT_EQ(~b.value(), ~b); // ~non-zero
+ static_assert(~d == ~d.value(), "Unary operator~ failed");
+
+ // -- ASSIGNMENT OPERATORS ---------------------------------------------------
+ // We test all assignment operators using IntType and constant as arguments.
+ // We also test the return from the operators.
+ // From same IntType
+ c = a = b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ // From constant
+ c = b = 2;
+ EXPECT_EQ(2, b.value());
+ EXPECT_EQ(2, c.value());
+ // From same IntType
+ c = a += b;
+ EXPECT_EQ(3, a.value());
+ EXPECT_EQ(3, c.value());
+ c = a -= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a *= b;
+ EXPECT_EQ(2, a.value());
+ EXPECT_EQ(2, c.value());
+ c = a /= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a <<= b;
+ EXPECT_EQ(4, a.value());
+ EXPECT_EQ(4, c.value());
+ c = a >>= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a %= b;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ // From constant
+ c = a += 2;
+ EXPECT_EQ(3, a.value());
+ EXPECT_EQ(3, c.value());
+ c = a -= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a *= 2;
+ EXPECT_EQ(2, a.value());
+ EXPECT_EQ(2, c.value());
+ c = a /= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a <<= 2;
+ EXPECT_EQ(4, a.value());
+ EXPECT_EQ(4, c.value());
+ c = a >>= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+ c = a %= 2;
+ EXPECT_EQ(1, a.value());
+ EXPECT_EQ(1, c.value());
+
+ // -- COMPARISON OPERATORS ---------------------------------------------------
+ a = 0;
+ b = 1;
+
+ EXPECT_FALSE(a == b);
+ EXPECT_TRUE(a == 0); // NOLINT
+ EXPECT_FALSE(1 == a); // NOLINT
+ static_assert(d == d, "operator== failed");
+ static_assert(d == 3, "operator== failed");
+ static_assert(3 == d, "operator== failed");
+ EXPECT_TRUE(a != b);
+ EXPECT_TRUE(a != 1); // NOLINT
+ EXPECT_FALSE(0 != a); // NOLINT
+ static_assert(d != e, "operator!= failed");
+ static_assert(d != 4, "operator!= failed");
+ static_assert(4 != d, "operator!= failed");
+ EXPECT_TRUE(a < b);
+ EXPECT_TRUE(a < 1); // NOLINT
+ EXPECT_FALSE(0 < a); // NOLINT
+ static_assert(d < e, "operator< failed");
+ static_assert(d < 4, "operator< failed");
+ static_assert(3 < e, "operator< failed");
+ EXPECT_TRUE(a <= b);
+ EXPECT_TRUE(a <= 1); // NOLINT
+ EXPECT_TRUE(0 <= a); // NOLINT
+ static_assert(d <= e, "operator<= failed");
+ static_assert(d <= 4, "operator<= failed");
+ static_assert(3 <= e, "operator<= failed");
+ EXPECT_FALSE(a > b);
+ EXPECT_FALSE(a > 1); // NOLINT
+ EXPECT_FALSE(0 > a); // NOLINT
+ static_assert(e > d, "operator> failed");
+ static_assert(e > 3, "operator> failed");
+ static_assert(4 > d, "operator> failed");
+ EXPECT_FALSE(a >= b);
+ EXPECT_FALSE(a >= 1); // NOLINT
+ EXPECT_TRUE(0 >= a); // NOLINT
+ static_assert(e >= d, "operator>= failed");
+ static_assert(e >= 3, "operator>= failed");
+ static_assert(4 >= d, "operator>= failed");
+
+ // -- BINARY OPERATORS -------------------------------------------------------
+ a = 1;
+ b = 3;
+ EXPECT_EQ(4, (a + b).value());
+ EXPECT_EQ(4, (a + 3).value());
+ EXPECT_EQ(4, (1 + b).value());
+ static_assert((d + e).value() == 7, "Binary operator+ failed");
+ static_assert((d + 4).value() == 7, "Binary operator+ failed");
+ static_assert((3 + e).value() == 7, "Binary operator+ failed");
+ EXPECT_EQ(2, (b - a).value());
+ EXPECT_EQ(2, (b - 1).value());
+ EXPECT_EQ(2, (3 - a).value());
+ static_assert((e - d).value() == 1, "Binary operator- failed");
+ static_assert((e - 3).value() == 1, "Binary operator- failed");
+ static_assert((4 - d).value() == 1, "Binary operator- failed");
+ EXPECT_EQ(3, (a * b).value());
+ EXPECT_EQ(3, (a * 3).value());
+ EXPECT_EQ(3, (1 * b).value());
+ static_assert((d * e).value() == 12, "Binary operator* failed");
+ static_assert((d * 4).value() == 12, "Binary operator* failed");
+ static_assert((3 * e).value() == 12, "Binary operator* failed");
+ EXPECT_EQ(0, (a / b).value());
+ EXPECT_EQ(0, (a / 3).value());
+ EXPECT_EQ(0, (1 / b).value());
+ static_assert((d / e).value() == 0, "Binary operator/ failed");
+ static_assert((d / 4).value() == 0, "Binary operator/ failed");
+ static_assert((3 / e).value() == 0, "Binary operator/ failed");
+ EXPECT_EQ(8, (a << b).value());
+ EXPECT_EQ(8, (a << 3).value());
+ EXPECT_EQ(8, (1 << b).value());
+ static_assert((d << e).value() == 48, "Binary operator<< failed");
+ static_assert((d << 4).value() == 48, "Binary operator<< failed");
+ static_assert((3 << e).value() == 48, "Binary operator<< failed");
+ b = 8;
+ EXPECT_EQ(4, (b >> a).value());
+ EXPECT_EQ(4, (b >> 1).value());
+ EXPECT_EQ(4, (8 >> a).value());
+ static_assert((d >> e).value() == 0, "Binary operator>> failed");
+ static_assert((d >> 4).value() == 0, "Binary operator>> failed");
+ static_assert((3 >> e).value() == 0, "Binary operator>> failed");
+ b = 3;
+ a = 2;
+ EXPECT_EQ(1, (b % a).value());
+ EXPECT_EQ(1, (b % 2).value());
+ EXPECT_EQ(1, (3 % a).value());
+ static_assert((e % d).value() == 1, "Binary operator% failed");
+ static_assert((e % 3).value() == 1, "Binary operator% failed");
+ static_assert((4 % d).value() == 1, "Binary operator% failed");
+}
+
+TYPED_TEST(IntTypeTest, TestHashFunctor) {
+ std::unordered_map<typename TestFixture::T, char,
+ typename TestFixture::T::Hasher> map;
+ typename TestFixture::T a(0);
+ map[a] = 'c';
+ EXPECT_EQ('c', map[a]);
+ map[++a] = 'o';
+ EXPECT_EQ('o', map[a]);
+
+ typename TestFixture::T b(a);
+ EXPECT_EQ(typename TestFixture::T::Hasher()(a),
+ typename TestFixture::T::Hasher()(b));
+}
+
+// Tests the use of the templatized value accessor that performs static_casts.
+// We use -1 to force casting in unsigned integers.
+TYPED_TEST(IntTypeTest, TestValueAccessor) {
+ constexpr typename TestFixture::T::ValueType i = -1;
+ constexpr typename TestFixture::T int_type(i);
+ EXPECT_EQ(i, int_type.value());
+ static_assert(int_type.value() == i, "value() failed");
+ // The keyword 'template' (suggested by Clang) is only necessary because this
+ // code is part of a template class. In most non-template code, plain
+ // int_type.value<int>() is all that is needed.
+ EXPECT_EQ(static_cast<int>(i), int_type.template value<int>());
+ EXPECT_EQ(static_cast<int8>(i), int_type.template value<int8>());
+ EXPECT_EQ(static_cast<int16>(i), int_type.template value<int16>());
+ EXPECT_EQ(static_cast<int32>(i), int_type.template value<int32>());
+ EXPECT_EQ(static_cast<uint32>(i), int_type.template value<uint32>());
+ EXPECT_EQ(static_cast<int64>(i), int_type.template value<int64>());
+ EXPECT_EQ(static_cast<uint64>(i), int_type.template value<uint64>());
+ EXPECT_EQ(static_cast<long>(i), int_type.template value<long>()); // NOLINT
+ static_assert(int_type.template value<int>() == static_cast<int>(i),
+ "value<Value>() failed");
+}
+
+TYPED_TEST(IntTypeTest, TestMove) {
+ // Check that the int types have move constructor/assignment.
+ // We do this by composing a struct with an int type and a unique_ptr. This
+ // struct can't be copied due to the unique_ptr, so it must be moved.
+ // If this compiles, it means that the int types have move operators.
+ struct NotCopyable {
+ typename TestFixture::T inttype;
+ std::unique_ptr<int> ptr;
+
+ static NotCopyable Make(int i) {
+ NotCopyable f;
+ f.inttype = typename TestFixture::T(i);
+ f.ptr.reset(new int(i));
+ return f;
+ }
+ };
+
+ // Test move constructor.
+ NotCopyable foo = NotCopyable::Make(123);
+ EXPECT_EQ(123, foo.inttype);
+ EXPECT_EQ(123, *foo.ptr);
+
+ // Test move assignment.
+ foo = NotCopyable::Make(321);
+ EXPECT_EQ(321, foo.inttype);
+ EXPECT_EQ(321, *foo.ptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/iterator_range.h b/tensorflow/core/lib/gtl/iterator_range.h
new file mode 100644
index 0000000000..baec85c40a
--- /dev/null
+++ b/tensorflow/core/lib/gtl/iterator_range.h
@@ -0,0 +1,49 @@
+// This provides a very simple, boring adaptor for a begin and end iterator
+// into a range type. This should be used to build range views that work well
+// with range based for loops and range based constructors.
+//
+// Note that code here follows more standards-based coding conventions as it
+// is mirroring proposed interfaces for standardization.
+//
+// Converted from chandlerc@'s code to Google style by joshl@.
+
+#ifndef TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_
+#define TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_
+
+#include <utility>
+
+namespace tensorflow {
+namespace gtl {
+
+// A range adaptor for a pair of iterators.
+//
+// This just wraps two iterators into a range-compatible interface. Nothing
+// fancy at all.
+template <typename IteratorT>
+class iterator_range {
+ public:
+ iterator_range() : begin_iterator_(), end_iterator_() {}
+ iterator_range(IteratorT begin_iterator, IteratorT end_iterator)
+ : begin_iterator_(std::move(begin_iterator)),
+ end_iterator_(std::move(end_iterator)) {}
+
+ IteratorT begin() const { return begin_iterator_; }
+ IteratorT end() const { return end_iterator_; }
+
+ private:
+ IteratorT begin_iterator_, end_iterator_;
+};
+
+// Convenience function for iterating over sub-ranges.
+//
+// This provides a bit of syntactic sugar to make using sub-ranges
+// in for loops a bit easier. Analogous to std::make_pair().
+template <class T>
+iterator_range<T> make_range(T x, T y) {
+ return iterator_range<T>(std::move(x), std::move(y));
+}
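+
+// A minimal usage sketch (the vector below is illustrative): wrap a pair of
+// iterators so a sub-range can be used in a range-based for loop.
+//
+//   std::vector<int> v = {2, 3, 5, 7, 11, 13};
+//   for (int prime : gtl::make_range(v.begin() + 1, v.begin() + 4)) {
+//     // visits 3, 5, 7
+//   }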
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_ITERATOR_RANGE_H_
diff --git a/tensorflow/core/lib/gtl/iterator_range_test.cc b/tensorflow/core/lib/gtl/iterator_range_test.cc
new file mode 100644
index 0000000000..328be4ecbc
--- /dev/null
+++ b/tensorflow/core/lib/gtl/iterator_range_test.cc
@@ -0,0 +1,60 @@
+#include "tensorflow/core/lib/gtl/iterator_range.h"
+
+#include <vector>
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace gtl {
+namespace {
+
+TEST(IteratorRange, WholeVector) {
+ std::vector<int> v = {2, 3, 5, 7, 11, 13};
+ iterator_range<std::vector<int>::iterator> range(v.begin(), v.end());
+ int index = 0;
+ for (int prime : range) {
+ ASSERT_LT(index, v.size());
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(v.size(), index);
+}
+
+TEST(IteratorRange, VectorMakeRange) {
+ std::vector<int> v = {2, 3, 5, 7, 11, 13};
+ auto range = make_range(v.begin(), v.end());
+ int index = 0;
+ for (int prime : range) {
+ ASSERT_LT(index, v.size());
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(v.size(), index);
+}
+
+TEST(IteratorRange, PartArray) {
+ int v[] = {2, 3, 5, 7, 11, 13};
+ iterator_range<int*> range(&v[1], &v[4]); // 3, 5, 7
+ int index = 1;
+ for (int prime : range) {
+ ASSERT_LT(index, TF_ARRAYSIZE(v));
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(4, index);
+}
+
+TEST(IteratorRange, ArrayMakeRange) {
+ int v[] = {2, 3, 5, 7, 11, 13};
+ auto range = make_range(&v[1], &v[4]); // 3, 5, 7
+ int index = 1;
+ for (int prime : range) {
+ ASSERT_LT(index, TF_ARRAYSIZE(v));
+ EXPECT_EQ(v[index], prime);
+ ++index;
+ }
+ EXPECT_EQ(4, index);
+}
+} // namespace
+} // namespace gtl
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/manual_constructor.h b/tensorflow/core/lib/gtl/manual_constructor.h
new file mode 100644
index 0000000000..39f029ed4a
--- /dev/null
+++ b/tensorflow/core/lib/gtl/manual_constructor.h
@@ -0,0 +1,230 @@
+// ManualConstructor statically allocates space in which to store some
+// object, but does not initialize it. You can then call the constructor
+// and destructor for the object yourself as you see fit. This is useful
+// for memory management optimizations, where you want to initialize and
+// destroy an object multiple times but only allocate it once.
+//
+// (When I say ManualConstructor statically allocates space, I mean that
+// the ManualConstructor object itself is forced to be the right size.)
+
+#ifndef TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_
+#define TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_
+
+#include <stddef.h>
+#include <new>
+#include <utility>
+
+#include "tensorflow/core/platform/port.h" // For aligned_malloc/aligned_free
+
+namespace tensorflow {
+namespace gtl {
+namespace internal {
+
+//
+// Provides a char array with the exact same alignment as another type. The
+// first parameter must be a complete type; the second parameter is how many
+// of that type to provide space for.
+//
+// TF_LIB_GTL_ALIGNED_CHAR_ARRAY(struct stat, 16) storage_;
+//
+// Because MSVC and older GCCs require that the argument to their alignment
+// construct be a literal constant integer, we use a template instantiated
+// at all the possible powers of two.
+#ifndef SWIG
+template <int alignment, int size>
+struct AlignType {};
+template <int size>
+struct AlignType<0, size> {
+ typedef char result[size];
+};
+#if defined(COMPILER_MSVC)
+#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __declspec(align(X))
+#define TF_LIB_GTL_ALIGN_OF(T) __alignof(T)
+#elif defined(COMPILER_GCC3) || __GNUC__ >= 3 || defined(__APPLE__) || \
+ defined(COMPILER_ICC) || defined(OS_NACL) || defined(__clang__)
+#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __attribute__((aligned(X)))
+#define TF_LIB_GTL_ALIGN_OF(T) __alignof__(T)
+#endif
+
+#if defined(TF_LIB_GTL_ALIGN_ATTRIBUTE)
+
+#define TF_LIB_GTL_ALIGNTYPE_TEMPLATE(X) \
+ template <int size> \
+ struct AlignType<X, size> { \
+ typedef TF_LIB_GTL_ALIGN_ATTRIBUTE(X) char result[size]; \
+ }
+
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(16);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(32);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(64);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(128);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(256);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(512);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1024);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2048);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4096);
+TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8192);
+// Any larger and MSVC++ will complain.
+
+#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \
+ typename tensorflow::gtl::internal::AlignType<TF_LIB_GTL_ALIGN_OF(T), \
+ sizeof(T) * Size>::result
+
+#undef TF_LIB_GTL_ALIGNTYPE_TEMPLATE
+#undef TF_LIB_GTL_ALIGN_ATTRIBUTE
+
+#else // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE)
+#error "You must define TF_LIB_GTL_ALIGNED_CHAR_ARRAY for your compiler."
+#endif // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE)
+
+#else // !SWIG
+
+// SWIG can't represent alignment and doesn't care about alignment on data
+// members (it works fine without it).
+template <typename Size>
+struct AlignType {
+ typedef char result[Size];
+};
+#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \
+ tensorflow::gtl::internal::AlignType<Size * sizeof(T)>::result
+
+// Enough to parse with SWIG, will never be used by running code.
+#define TF_LIB_GTL_ALIGN_OF(Type) 16
+
+#endif // !SWIG
+
+} // namespace internal
+} // namespace gtl
+
+template <typename Type>
+class ManualConstructor {
+ public:
+ // No constructor or destructor because one of the most useful uses of
+ // this class is as part of a union, and members of a union cannot have
+ // constructors or destructors. And, anyway, the whole point of this
+ // class is to bypass these.
+
+ // Support users creating arrays of ManualConstructor<>s. This ensures that
+ // the array itself has the correct alignment.
+ static void* operator new[](size_t size) {
+ return port::aligned_malloc(size, TF_LIB_GTL_ALIGN_OF(Type));
+ }
+ static void operator delete[](void* mem) { port::aligned_free(mem); }
+
+ inline Type* get() { return reinterpret_cast<Type*>(space_); }
+ inline const Type* get() const {
+ return reinterpret_cast<const Type*>(space_);
+ }
+
+ inline Type* operator->() { return get(); }
+ inline const Type* operator->() const { return get(); }
+
+ inline Type& operator*() { return *get(); }
+ inline const Type& operator*() const { return *get(); }
+
+ inline void Init() { new (space_) Type; }
+
+// Init() constructs the Type instance using the given arguments
+// (which are forwarded to Type's constructor). In C++11, Init() can
+// take any number of arguments of any type, and forwards them perfectly.
+// On pre-C++11 platforms, it can take up to 11 arguments, and may not be
+// able to forward certain kinds of arguments.
+//
+// Note that Init() with no arguments performs default-initialization,
+// not zero-initialization (i.e., it behaves the same as "new Type;", not
+// "new Type();"), so it will leave non-class types uninitialized.
+#ifdef LANG_CXX11
+ template <typename... Ts>
+ inline void Init(Ts&&... args) { // NOLINT
+ new (space_) Type(std::forward<Ts>(args)...); // NOLINT
+ }
+#else // !defined(LANG_CXX11)
+ template <typename T1>
+ inline void Init(const T1& p1) {
+ new (space_) Type(p1);
+ }
+
+ template <typename T1, typename T2>
+ inline void Init(const T1& p1, const T2& p2) {
+ new (space_) Type(p1, p2);
+ }
+
+ template <typename T1, typename T2, typename T3>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3) {
+ new (space_) Type(p1, p2, p3);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4) {
+ new (space_) Type(p1, p2, p3, p4);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5) {
+ new (space_) Type(p1, p2, p3, p4, p5);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8,
+ const T9& p9) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8,
+ const T9& p9, const T10& p10) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10);
+ }
+
+ template <typename T1, typename T2, typename T3, typename T4, typename T5,
+ typename T6, typename T7, typename T8, typename T9, typename T10,
+ typename T11>
+ inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4,
+ const T5& p5, const T6& p6, const T7& p7, const T8& p8,
+ const T9& p9, const T10& p10, const T11& p11) {
+ new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11);
+ }
+#endif // LANG_CXX11
+
+ inline void Destroy() { get()->~Type(); }
+
+ private:
+ TF_LIB_GTL_ALIGNED_CHAR_ARRAY(Type, 1) space_;
+};
+
+#undef TF_LIB_GTL_ALIGNED_CHAR_ARRAY
+#undef TF_LIB_GTL_ALIGN_OF
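+
+// A minimal usage sketch (the Slot union below is illustrative): construct and
+// destroy the wrapped object explicitly, for example when storing it in a
+// union.
+//
+//   union Slot {
+//     ManualConstructor<string> str;
+//     ManualConstructor<int64> num;
+//   };
+//
+//   Slot s;
+//   s.str.Init("hello");      // placement-new construction
+//   s.str->append(" world");  // operator-> gives access to the wrapped object
+//   s.str.Destroy();          // explicit destructor call before reusing s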
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_MANUAL_CONSTRUCTOR_H_
diff --git a/tensorflow/core/lib/gtl/manual_constructor_test.cc b/tensorflow/core/lib/gtl/manual_constructor_test.cc
new file mode 100644
index 0000000000..a929591be2
--- /dev/null
+++ b/tensorflow/core/lib/gtl/manual_constructor_test.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/gtl/manual_constructor.h"
+
+#include <stdint.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+static int constructor_count_ = 0;
+
+template <int kSize>
+struct TestN {
+ TestN() { ++constructor_count_; }
+ ~TestN() { --constructor_count_; }
+ char a[kSize];
+};
+
+typedef TestN<1> Test1;
+typedef TestN<2> Test2;
+typedef TestN<3> Test3;
+typedef TestN<4> Test4;
+typedef TestN<5> Test5;
+typedef TestN<9> Test9;
+typedef TestN<15> Test15;
+
+} // namespace
+
+namespace {
+
+TEST(ManualConstructorTest, Sizeof) {
+ CHECK_EQ(sizeof(ManualConstructor<Test1>), sizeof(Test1));
+ CHECK_EQ(sizeof(ManualConstructor<Test2>), sizeof(Test2));
+ CHECK_EQ(sizeof(ManualConstructor<Test3>), sizeof(Test3));
+ CHECK_EQ(sizeof(ManualConstructor<Test4>), sizeof(Test4));
+ CHECK_EQ(sizeof(ManualConstructor<Test5>), sizeof(Test5));
+ CHECK_EQ(sizeof(ManualConstructor<Test9>), sizeof(Test9));
+ CHECK_EQ(sizeof(ManualConstructor<Test15>), sizeof(Test15));
+
+ CHECK_EQ(constructor_count_, 0);
+ ManualConstructor<Test1> mt[4];
+ CHECK_EQ(sizeof(mt), 4);
+ CHECK_EQ(constructor_count_, 0);
+ mt[0].Init();
+ CHECK_EQ(constructor_count_, 1);
+ mt[0].Destroy();
+}
+
+TEST(ManualConstructorTest, Alignment) {
+ // We want to make sure that ManualConstructor aligns its memory properly
+ // on a word barrier. Otherwise, it might be unexpectedly slow, since
+ // memory access will be unaligned.
+
+ struct {
+ char a;
+ ManualConstructor<void*> b;
+ } test1;
+ struct {
+ char a;
+ void* b;
+ } control1;
+
+ // TODO(bww): Make these tests more direct with C++11 alignment_of<T>::value.
+ EXPECT_EQ(reinterpret_cast<char*>(test1.b.get()) - &test1.a,
+ reinterpret_cast<char*>(&control1.b) - &control1.a);
+ EXPECT_EQ(reinterpret_cast<intptr_t>(test1.b.get()) % sizeof(control1.b), 0);
+
+ struct {
+ char a;
+ ManualConstructor<long double> b;
+ } test2;
+ struct {
+ char a;
+ long double b;
+ } control2;
+
+ EXPECT_EQ(reinterpret_cast<char*>(test2.b.get()) - &test2.a,
+ reinterpret_cast<char*>(&control2.b) - &control2.a);
+#ifdef ARCH_K8
+ EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 16, 0);
+#endif
+#ifdef ARCH_PIII
+ EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 4, 0);
+#endif
+}
+
+TEST(ManualConstructorTest, DefaultInitialize) {
+ struct X {
+ X() : x(123) {}
+ int x;
+ };
+ union {
+ ManualConstructor<X> x;
+ ManualConstructor<int> y;
+ } u;
+ *u.y = -1;
+ u.x.Init(); // should default-initialize u.x
+ EXPECT_EQ(123, u.x->x);
+}
+
+TEST(ManualConstructorTest, ZeroInitializePOD) {
+ union {
+ ManualConstructor<int> x;
+ ManualConstructor<int> y;
+ } u;
+ *u.y = -1;
+ u.x.Init(); // should not zero-initialize u.x
+ EXPECT_EQ(-1, *u.y);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h
new file mode 100644
index 0000000000..c953de57c7
--- /dev/null
+++ b/tensorflow/core/lib/gtl/map_util.h
@@ -0,0 +1,123 @@
+// This file provides utility functions for use with STL map-like data
+// structures, such as std::map and hash_map. Some functions, such as
+// ContainsKey(), also work with sets.
+
+#ifndef TENSORFLOW_LIB_GTL_MAP_UTIL_H_
+#define TENSORFLOW_LIB_GTL_MAP_UTIL_H_
+
+#include <stddef.h>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace gtl {
+
+// Returns a pointer to the const value associated with the given key if it
+// exists, or NULL otherwise.
+template <class Collection>
+const typename Collection::value_type::second_type* FindOrNull(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return 0;
+ }
+ return &it->second;
+}
+
+// Same as above but returns a pointer to the non-const value.
+template <class Collection>
+typename Collection::value_type::second_type* FindOrNull(
+ Collection& collection, // NOLINT
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return 0;
+ }
+ return &it->second;
+}
+
+// Returns the pointer value associated with the given key. If none is found,
+// NULL is returned. The function is designed to be used with a map of keys to
+// pointers.
+//
+// This function does not distinguish between a missing key and a key mapped
+// to a NULL value.
+template <class Collection>
+typename Collection::value_type::second_type FindPtrOrNull(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return typename Collection::value_type::second_type();
+ }
+ return it->second;
+}
+
+// Returns a const reference to the value associated with the given key if it
+// exists, otherwise returns a const reference to the provided default value.
+//
+// WARNING: If a temporary object is passed as the default "value,"
+// this function will return a reference to that temporary object,
+// which will be destroyed at the end of the statement. A common
+// example: if you have a map with string values, and you pass a char*
+// as the default "value," either use the returned value immediately
+// or store it in a string (not string&).
+template <class Collection>
+const typename Collection::value_type::second_type& FindWithDefault(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& value) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ return value;
+ }
+ return it->second;
+}
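+
+// A minimal usage sketch of the lookup helpers (the map contents below are
+// illustrative):
+//
+//   std::map<string, int> counts;
+//   counts["ops"] = 42;
+//   const int* n = gtl::FindOrNull(counts, "ops");       // points at 42
+//   const int* m = gtl::FindOrNull(counts, "kernels");   // NULL; key absent
+//   int k = gtl::FindWithDefault(counts, "kernels", 0);  // 0; key absent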
+
+// Inserts the given key and value into the given collection if and only if the
+// given key did NOT already exist in the collection. If the key previously
+// existed in the collection, the value is not changed. Returns true if the
+// key-value pair was inserted; returns false if the key was already present.
+template <class Collection>
+bool InsertIfNotPresent(Collection* const collection,
+ const typename Collection::value_type& vt) {
+ return collection->insert(vt).second;
+}
+
+// Same as above except the key and value are passed separately.
+template <class Collection>
+bool InsertIfNotPresent(
+ Collection* const collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& value) {
+ return InsertIfNotPresent(collection,
+ typename Collection::value_type(key, value));
+}
+
+// Looks up a given key and value pair in a collection and inserts the key-value
+// pair if it's not already present. Returns a reference to the value associated
+// with the key.
+template <class Collection>
+typename Collection::value_type::second_type& LookupOrInsert(
+ Collection* const collection, const typename Collection::value_type& vt) {
+ return collection->insert(vt).first->second;
+}
+
+// Same as above except the key-value are passed separately.
+template <class Collection>
+typename Collection::value_type::second_type& LookupOrInsert(
+ Collection* const collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& value) {
+ return LookupOrInsert(collection,
+ typename Collection::value_type(key, value));
+}
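+
+// A minimal usage sketch of the insertion helpers (the map below is
+// illustrative):
+//
+//   std::map<string, string> config;
+//   gtl::InsertIfNotPresent(&config, "mode", "train");  // returns true
+//   gtl::InsertIfNotPresent(&config, "mode", "eval");   // returns false; "train" kept
+//   string& device = gtl::LookupOrInsert(&config, "device", "cpu");
+//   device = "gpu";  // the returned reference aliases the stored value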
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_MAP_UTIL_H_
diff --git a/tensorflow/core/lib/gtl/map_util_test.cc b/tensorflow/core/lib/gtl/map_util_test.cc
new file mode 100644
index 0000000000..356f987337
--- /dev/null
+++ b/tensorflow/core/lib/gtl/map_util_test.cc
@@ -0,0 +1,47 @@
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+#include <map>
+#include <set>
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(MapUtil, Find) {
+ typedef std::map<string, string> Map;
+ Map m;
+
+ // Check that I can use a type that's implicitly convertible to the
+ // key or value type, such as const char* -> string.
+ EXPECT_EQ("", gtl::FindWithDefault(m, "foo", ""));
+ m["foo"] = "bar";
+ EXPECT_EQ("bar", gtl::FindWithDefault(m, "foo", ""));
+ EXPECT_EQ("bar", *gtl::FindOrNull(m, "foo"));
+ string str;
+ EXPECT_TRUE(m.count("foo") > 0);
+ EXPECT_EQ(m["foo"], "bar");
+}
+
+TEST(MapUtil, LookupOrInsert) {
+ typedef std::map<string, string> Map;
+ Map m;
+
+ // Check that I can use a type that's implicitly convertible to the
+ // key or value type, such as const char* -> string.
+ EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "xyz"));
+ EXPECT_EQ("xyz", gtl::LookupOrInsert(&m, "foo", "abc"));
+}
+
+TEST(MapUtil, InsertIfNotPresent) {
+ // Set operations
+ typedef std::set<int> Set;
+ Set s;
+ EXPECT_TRUE(gtl::InsertIfNotPresent(&s, 0));
+ EXPECT_EQ(s.count(0), 1);
+ EXPECT_FALSE(gtl::InsertIfNotPresent(&s, 0));
+ EXPECT_EQ(s.count(0), 1);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/stl_util.h b/tensorflow/core/lib/gtl/stl_util.h
new file mode 100644
index 0000000000..83abcd6b55
--- /dev/null
+++ b/tensorflow/core/lib/gtl/stl_util.h
@@ -0,0 +1,130 @@
+// This file provides utility functions for use with STL
+
+#ifndef TENSORFLOW_LIB_GTL_STL_UTIL_H_
+#define TENSORFLOW_LIB_GTL_STL_UTIL_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace gtl {
+
+// Returns a mutable char* pointing to a string's internal buffer, which may not
+// be null-terminated. Returns NULL for an empty string. If not non-null,
+// writing through this pointer will modify the string.
+//
+// string_as_array(&str)[i] is valid for 0 <= i < str.size() until the
+// next call to a string method that invalidates iterators.
+//
+// In C++11 you may simply use &str[0] to get a mutable char*.
+//
+// Prior to C++11, there was no standard-blessed way of getting a mutable
+// reference to a string's internal buffer. The requirement that string be
+// contiguous is officially part of the C++11 standard [string.require]/5.
+// According to Matt Austern, this should already work on all current C++98
+// implementations.
+inline char* string_as_array(string* str) {
+ return str->empty() ? NULL : &*str->begin();
+}
+
+// Returns the T* array for the given vector, or NULL if the vector was empty.
+//
+// Note: If you know the array will never be empty, you can use &*v.begin()
+// directly, but that may dump core if v is empty. This function is the most
+// efficient code that will work, taking into account how our STL is actually
+// implemented. THIS IS NON-PORTABLE CODE, so use this function instead of
+// repeating the nonportable code everywhere. If our STL implementation changes,
+// we will need to change this as well.
+template <typename T, typename Allocator>
+inline T* vector_as_array(std::vector<T, Allocator>* v) {
+#if defined NDEBUG && !defined _GLIBCXX_DEBUG
+ return &*v->begin();
+#else
+ return v->empty() ? NULL : &*v->begin();
+#endif
+}
+// vector_as_array overload for const std::vector<>.
+template <typename T, typename Allocator>
+inline const T* vector_as_array(const std::vector<T, Allocator>* v) {
+#if defined NDEBUG && !defined _GLIBCXX_DEBUG
+ return &*v->begin();
+#else
+ return v->empty() ? NULL : &*v->begin();
+#endif
+}
+
+// Like str->resize(new_size), except any new characters added to "*str" as a
+// result of resizing may be left uninitialized, rather than being filled with
+// '\0' bytes. Typically used when code is then going to overwrite the backing
+// store of the string with known data. Uses a Google extension to ::string.
+inline void STLStringResizeUninitialized(string* s, size_t new_size) {
+#if __google_stl_resize_uninitialized_string
+ s->resize_uninitialized(new_size);
+#else
+ s->resize(new_size);
+#endif
+}
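+
+// A minimal usage sketch (sizes and the 'src' buffer below are illustrative;
+// 'src' is assumed to hold at least 1024 readable bytes):
+//
+//   string buf;
+//   gtl::STLStringResizeUninitialized(&buf, 1024);
+//   memcpy(gtl::string_as_array(&buf), src, 1024);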
+
+// Calls delete (non-array version) on the SECOND item (pointer) in each pair in
+// the range [begin, end).
+//
+// Note: If you're calling this on an entire container, you probably want to
+// call STLDeleteValues(&container) instead, or use ValueDeleter.
+template <typename ForwardIterator>
+void STLDeleteContainerPairSecondPointers(ForwardIterator begin,
+ ForwardIterator end) {
+ while (begin != end) {
+ ForwardIterator temp = begin;
+ ++begin;
+ delete temp->second;
+ }
+}
+
+// Deletes all the elements in an STL container and clears the container. This
+// function is suitable for use with a vector, set, hash_set, or any other STL
+// container which defines sensible begin(), end(), and clear() methods.
+//
+// If container is NULL, this function is a no-op.
+template <typename T>
+void STLDeleteElements(T* container) {
+ if (!container) return;
+ auto it = container->begin();
+ while (it != container->end()) {
+ auto temp = it;
+ ++it;
+ delete *temp;
+ }
+ container->clear();
+}
+
+// Given an STL container consisting of (key, value) pairs, STLDeleteValues
+// deletes all the "value" components and clears the container. Does nothing in
+// the case it's given a NULL pointer.
+template <typename T>
+void STLDeleteValues(T* container) {
+ if (!container) return;
+ auto it = container->begin();
+ while (it != container->end()) {
+ auto temp = it;
+ ++it;
+ delete temp->second;
+ }
+ container->clear();
+}
+
+// Sorts and removes duplicates from a sequence container.
+template <typename T>
+inline void STLSortAndRemoveDuplicates(T* v) {
+ std::sort(v->begin(), v->end());
+ v->erase(std::unique(v->begin(), v->end()), v->end());
+}
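+
+// A minimal usage sketch (the vectors below are illustrative):
+//
+//   std::vector<int> ids = {3, 1, 3, 2, 1};
+//   gtl::STLSortAndRemoveDuplicates(&ids);  // ids is now {1, 2, 3}
+//
+//   std::vector<string*> owned;
+//   owned.push_back(new string("a"));
+//   owned.push_back(new string("b"));
+//   gtl::STLDeleteElements(&owned);  // deletes both strings and clears 'owned'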
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_STL_UTIL_H_
diff --git a/tensorflow/core/lib/gtl/top_n.h b/tensorflow/core/lib/gtl/top_n.h
new file mode 100644
index 0000000000..b95b998c21
--- /dev/null
+++ b/tensorflow/core/lib/gtl/top_n.h
@@ -0,0 +1,324 @@
+// This simple class finds the top n elements of an incrementally provided set
+// of elements which you push one at a time. If the number of elements exceeds
+// n, the lowest elements are incrementally dropped. At the end you get
+// a vector of the top elements sorted in descending order (through Extract() or
+// ExtractNondestructive()), or a vector of the top elements but not sorted
+// (through ExtractUnsorted() or ExtractUnsortedNondestructive()).
+//
+// The value n is specified in the constructor. If there are p elements pushed
+// altogether:
+// The total storage requirements are O(min(n, p)) elements
+// The running time is O(p * log(min(n, p))) comparisons
+// If n is a constant, the total storage required is a constant and the running
+// time is linear in p.
+//
+// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p)
+// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements,
+// discarding the lowest n elements whenever the buffer is full using a linear-
+// time median algorithm. This may have better performance when the input
+// sequence is partially sorted.
+//
+// NOTE(zhifengc): This class should be redesigned to avoid reallocating a
+// vector for each Extract.
+
+#ifndef TENSORFLOW_LIB_GTL_TOP_N_H_
+#define TENSORFLOW_LIB_GTL_TOP_N_H_
+
+#include <stddef.h>
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate,
+// not the more commonly used "less" predicate.
+//
+// If you use a "less" predicate here, the TopN will pick out the bottom N
+// elements out of the ones passed to it, and it will return them sorted in
+// ascending order.
+//
+// TopN is rule-of-zero copyable and movable if its members are.
+template <class T, class Cmp = std::greater<T> >
+class TopN {
+ public:
+ // The TopN is in one of the three states:
+ //
+ // o UNORDERED: this is the state an instance is originally in,
+ // where the elements are completely orderless.
+ //
+ // o BOTTOM_KNOWN: in this state, we keep the invariant that there
+ // is at least one element in it, and the lowest element is at
+ // position 0. The elements in other positions remain
+ // unsorted. This state is reached if the state was originally
+ // UNORDERED and a peek_bottom() function call is invoked.
+ //
+ // o HEAP_SORTED: in this state, the array is kept as a heap and
+ // there are exactly (limit_+1) elements in the array. This
+ // state is reached when at least (limit_+1) elements are
+ // pushed in.
+ //
+ // The state transition graph is as follows:
+ //
+ // peek_bottom() (limit_+1) elements
+ // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED
+ // | ^
+ // | (limit_+1) elements |
+ // +-----------------------------------------------------------+
+
+ enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED };
+ using UnsortedIterator = typename std::vector<T>::const_iterator;
+
+ // 'limit' is the maximum number of top results to return.
+ explicit TopN(size_t limit) : TopN(limit, Cmp()) {}
+ TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {}
+
+ size_t limit() const { return limit_; }
+
+ // Number of elements currently held by this TopN object. This
+ // will be no greater than 'limit' passed to the constructor.
+ size_t size() const { return std::min(elements_.size(), limit_); }
+
+ bool empty() const { return size() == 0; }
+
+ // If you know how many elements you will push at the time you create the
+ // TopN object, you can call reserve to preallocate the memory that TopN
+ // will need to process all 'n' pushes. Calling this method is optional.
+ void reserve(size_t n) { elements_.reserve(std::min(n, limit_ + 1)); }
+
+ // Push 'v'. If the maximum number of elements was exceeded, drop the
+ // lowest element and return it in 'dropped' (if given). If the maximum is not
+ // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or
+ // nullptr, in which case it is not filled in.
+ // Requires: T is CopyAssignable, Swappable
+ void push(const T &v) { push(v, nullptr); }
+ void push(const T &v, T *dropped) { PushInternal(v, dropped); }
+
+ // Move overloads of push.
+ // Requires: T is MoveAssignable, Swappable
+ void push(T &&v) { // NOLINT(build/c++11)
+ push(std::move(v), nullptr);
+ }
+ void push(T &&v, T *dropped) { // NOLINT(build/c++11)
+ PushInternal(std::move(v), dropped);
+ }
+
+ // Peeks the bottom result without calling Extract()
+ const T &peek_bottom();
+
+ // Extract the elements as a vector sorted in descending order. The caller
+ // assumes ownership of the vector and must delete it when done. This is a
+ // destructive operation. The only method that can be called immediately
+ // after Extract() is Reset().
+ std::vector<T> *Extract();
+
+ // Similar to Extract(), but makes no guarantees the elements are in sorted
+ // order. As with Extract(), the caller assumes ownership of the vector and
+ // must delete it when done. This is a destructive operation. The only
+ // method that can be called immediately after ExtractUnsorted() is Reset().
+ std::vector<T> *ExtractUnsorted();
+
+ // A non-destructive version of Extract(). Copy the elements into a new vector
+ // sorted in descending order and return it. The caller assumes ownership of
+ // the new vector and must delete it when done. After calling
+ // ExtractNondestructive(), the caller can continue to push() new elements.
+ std::vector<T> *ExtractNondestructive() const;
+
+ // A non-destructive version of Extract(). Copy the elements to a given
+ // vector sorted in descending order. After calling
+ // ExtractNondestructive(), the caller can continue to push() new elements.
+ // Note:
+ // 1. The given argument must be allocated.
+ // 2. Any data contained in the vector prior to the call will be deleted
+ // from it. After the call the vector will contain only the elements
+ // from the data structure.
+ void ExtractNondestructive(std::vector<T> *output) const;
+
+ // A non-destructive version of ExtractUnsorted(). Copy the elements into a new
+ // vector and return it, with no guarantees the elements are in sorted order.
+ // The caller assumes ownership of the new vector and must delete it when
+ // done. After calling ExtractUnsortedNondestructive(), the caller can
+ // continue to push() new elements.
+ std::vector<T> *ExtractUnsortedNondestructive() const;
+
+ // A non-destructive version of ExtractUnsorted(). Copy the elements into
+ // a given vector, with no guarantees the elements are in sorted order.
+ // After calling ExtractUnsortedNondestructive(), the caller can continue
+ // to push() new elements.
+ // Note:
+ // 1. The given argument must be allocated.
+ // 2. Any data contained in the vector prior to the call will be deleted
+ // from it. After the call the vector will contain only the elements
+ // from the data structure.
+ void ExtractUnsortedNondestructive(std::vector<T> *output) const;
+
+ // Return an iterator to the beginning (end) of the container,
+ // with no guarantees about the order of iteration. These iterators are
+ // invalidated by mutation of the data structure.
+ UnsortedIterator unsorted_begin() const { return elements_.begin(); }
+ UnsortedIterator unsorted_end() const { return elements_.begin() + size(); }
+
+ // Accessor for comparator template argument.
+ Cmp *comparator() { return &cmp_; }
+
+ // This removes all elements. If Extract() or ExtractUnsorted() has been
+ // called, this will put it back in an empty but usable state.
+ void Reset();
+
+ private:
+ template <typename U>
+ void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11)
+
+ // elements_ can be in one of two states:
+ // elements_.size() <= limit_: elements_ is an unsorted vector of elements
+ // pushed so far.
+ // elements_.size() > limit_: The last element of elements_ is unused;
+ // the other elements of elements_ are an stl heap whose size is exactly
+ // limit_. In this case elements_.size() is exactly one greater than
+ // limit_, but don't use "elements_.size() == limit_ + 1" to check for
+ // that because you'll get a false positive if limit_ == size_t(-1).
+ std::vector<T> elements_;
+ size_t limit_; // Maximum number of elements to find
+ Cmp cmp_; // Greater-than comparison function
+ State state_ = UNORDERED;
+};
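+
+// A minimal usage sketch (the values below are illustrative): keep the three
+// largest integers seen so far.
+//
+//   TopN<int> top(3);
+//   for (int x : {5, 1, 9, 7, 3}) top.push(x);
+//   std::unique_ptr<std::vector<int>> best(top.Extract());  // caller owns
+//   // *best is {9, 7, 5}, sorted in descending order. Call Reset() before
+//   // pushing into 'top' again.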
+
+// ----------------------------------------------------------------------
+// Implementations of non-inline functions
+
+template <class T, class Cmp>
+template <typename U>
+void TopN<T, Cmp>::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11)
+ if (limit_ == 0) {
+ if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
+ return;
+ }
+ if (state_ != HEAP_SORTED) {
+ elements_.push_back(std::forward<U>(v)); // NOLINT(build/c++11)
+ if (state_ == UNORDERED || cmp_(elements_.back(), elements_.front())) {
+ // Easy case: we just pushed the new element back
+ } else {
+ // To maintain the BOTTOM_KNOWN state, we need to make sure that
+ // the element at position 0 is always the smallest. So we put
+ // the new element at position 0 and push the original bottom
+ // element in the back.
+ // Warning: this code is subtle.
+ using std::swap;
+ swap(elements_.front(), elements_.back());
+ }
+ if (elements_.size() == limit_ + 1) {
+ // Transition from unsorted vector to a heap.
+ std::make_heap(elements_.begin(), elements_.end(), cmp_);
+ if (dropped) *dropped = std::move(elements_.front());
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_);
+ state_ = HEAP_SORTED;
+ }
+ } else {
+ // Only insert the new element if it is greater than the least element.
+ if (cmp_(v, elements_.front())) {
+ elements_.back() = std::forward<U>(v); // NOLINT(build/c++11)
+ std::push_heap(elements_.begin(), elements_.end(), cmp_);
+ if (dropped) *dropped = std::move(elements_.front());
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_);
+ } else {
+ if (dropped) *dropped = std::forward<U>(v); // NOLINT(build/c++11)
+ }
+ }
+}
+
+template <class T, class Cmp>
+const T &TopN<T, Cmp>::peek_bottom() {
+ CHECK(!empty());
+ if (state_ == UNORDERED) {
+ // We need to do a linear scan to find out the bottom element
+ int min_candidate = 0;
+ for (size_t i = 1; i < elements_.size(); ++i) {
+ if (cmp_(elements_[min_candidate], elements_[i])) {
+ min_candidate = i;
+ }
+ }
+ // By swapping the element at position 0 and the minimal
+ // element, we transition to the BOTTOM_KNOWN state
+ if (min_candidate != 0) {
+ using std::swap;
+ swap(elements_[0], elements_[min_candidate]);
+ }
+ state_ = BOTTOM_KNOWN;
+ }
+ return elements_.front();
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::Extract() {
+ auto out = new std::vector<T>;
+ out->swap(elements_);
+ if (state_ != HEAP_SORTED) {
+ std::sort(out->begin(), out->end(), cmp_);
+ } else {
+ out->pop_back();
+ std::sort_heap(out->begin(), out->end(), cmp_);
+ }
+ return out;
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractUnsorted() {
+ auto out = new std::vector<T>;
+ out->swap(elements_);
+ if (state_ == HEAP_SORTED) {
+ // Remove the limit_+1'th element.
+ out->pop_back();
+ }
+ return out;
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractNondestructive() const {
+ auto out = new std::vector<T>;
+ ExtractNondestructive(out);
+ return out;
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::ExtractNondestructive(std::vector<T> *output) const {
+ CHECK(output);
+ *output = elements_;
+ if (state_ != HEAP_SORTED) {
+ std::sort(output->begin(), output->end(), cmp_);
+ } else {
+ output->pop_back();
+ std::sort_heap(output->begin(), output->end(), cmp_);
+ }
+}
+
+template <class T, class Cmp>
+std::vector<T> *TopN<T, Cmp>::ExtractUnsortedNondestructive() const {
+ auto elements = new std::vector<T>;
+ ExtractUnsortedNondestructive(elements);
+ return elements;
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::ExtractUnsortedNondestructive(std::vector<T> *output) const {
+ CHECK(output);
+ *output = elements_;
+ if (state_ == HEAP_SORTED) {
+ // Remove the limit_+1'th element.
+ output->pop_back();
+ }
+}
+
+template <class T, class Cmp>
+void TopN<T, Cmp>::Reset() {
+ elements_.clear();
+ state_ = UNORDERED;
+}
+
+} // namespace gtl
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_GTL_TOP_N_H_
diff --git a/tensorflow/core/lib/gtl/top_n_test.cc b/tensorflow/core/lib/gtl/top_n_test.cc
new file mode 100644
index 0000000000..1812a1bd3f
--- /dev/null
+++ b/tensorflow/core/lib/gtl/top_n_test.cc
@@ -0,0 +1,249 @@
+// Unit test for TopN.
+
+#include "tensorflow/core/lib/gtl/top_n.h"
+
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace {
+
+using tensorflow::gtl::TopN;
+using tensorflow::random::PhiloxRandom;
+using tensorflow::random::SimplePhilox;
+using tensorflow::string;
+
+// Move the contents from an owned raw pointer, returning by value.
+// Objects are easier to manage by value.
+template <class T>
+T ConsumeRawPtr(T *p) {
+ T tmp = std::move(*p);
+ delete p;
+ return tmp;
+}
+
+template <class Cmp>
+void TestIntTopNHelper(size_t limit, size_t n_elements, const Cmp &cmp,
+ SimplePhilox *random, bool test_peek,
+ bool test_extract_unsorted) {
+ LOG(INFO) << "Testing limit=" << limit << ", n_elements=" << n_elements
+ << ", test_peek=" << test_peek
+ << ", test_extract_unsorted=" << test_extract_unsorted;
+ TopN<int, Cmp> top(limit, cmp);
+ std::vector<int> shadow(n_elements);
+ for (int i = 0; i != n_elements; ++i) shadow[i] = random->Uniform(limit);
+ for (int e : shadow) top.push(e);
+ std::sort(shadow.begin(), shadow.end(), cmp);
+ size_t top_size = std::min(limit, n_elements);
+ EXPECT_EQ(top_size, top.size());
+ if (test_peek && top_size != 0) {
+ EXPECT_EQ(shadow[top_size - 1], top.peek_bottom());
+ }
+ std::vector<int> v;
+ if (test_extract_unsorted) {
+ v = ConsumeRawPtr(top.ExtractUnsorted());
+ std::sort(v.begin(), v.end(), cmp);
+ } else {
+ v = ConsumeRawPtr(top.Extract());
+ }
+ EXPECT_EQ(top_size, v.size());
+ for (int i = 0; i != top_size; ++i) {
+ VLOG(1) << "Top element " << v[i];
+ EXPECT_EQ(shadow[i], v[i]);
+ }
+}
+
+template <class Cmp>
+void TestIntTopN(size_t limit, size_t n_elements, const Cmp &cmp,
+ SimplePhilox *random) {
+ // Test peek_bottom() and Extract()
+ TestIntTopNHelper(limit, n_elements, cmp, random, true, false);
+ // Test Extract()
+ TestIntTopNHelper(limit, n_elements, cmp, random, false, false);
+ // Test peek_bottom() and ExtractUnsorted()
+ TestIntTopNHelper(limit, n_elements, cmp, random, true, true);
+ // Test ExtractUnsorted()
+ TestIntTopNHelper(limit, n_elements, cmp, random, false, true);
+}
+
+TEST(TopNTest, Misc) {
+ PhiloxRandom philox(1, 1);
+ SimplePhilox random(&philox);
+
+ TestIntTopN(0, 5, std::greater<int>(), &random);
+ TestIntTopN(32, 0, std::greater<int>(), &random);
+ TestIntTopN(6, 6, std::greater<int>(), &random);
+ TestIntTopN(6, 6, std::less<int>(), &random);
+ TestIntTopN(1000, 999, std::greater<int>(), &random);
+ TestIntTopN(1000, 1000, std::greater<int>(), &random);
+ TestIntTopN(1000, 1001, std::greater<int>(), &random);
+ TestIntTopN(2300, 28393, std::less<int>(), &random);
+ TestIntTopN(30, 100, std::greater<int>(), &random);
+ TestIntTopN(100, 30, std::less<int>(), &random);
+ TestIntTopN(size_t(-1), 3, std::greater<int>(), &random);
+ TestIntTopN(size_t(-1), 0, std::greater<int>(), &random);
+ TestIntTopN(0, 5, std::greater<int>(), &random);
+}
+
+TEST(TopNTest, String) {
+ LOG(INFO) << "Testing strings";
+
+ TopN<string> top(3);
+ EXPECT_TRUE(top.empty());
+ top.push("abracadabra");
+ top.push("waldemar");
+ EXPECT_EQ(2, top.size());
+ EXPECT_EQ("abracadabra", top.peek_bottom());
+ top.push("");
+ EXPECT_EQ(3, top.size());
+ EXPECT_EQ("", top.peek_bottom());
+ top.push("top");
+ EXPECT_EQ(3, top.size());
+ EXPECT_EQ("abracadabra", top.peek_bottom());
+ top.push("Google");
+ top.push("test");
+ EXPECT_EQ(3, top.size());
+ EXPECT_EQ("test", top.peek_bottom());
+ TopN<string> top2(top);
+ TopN<string> top3(5);
+ top3 = top;
+ EXPECT_EQ("test", top3.peek_bottom());
+ {
+ std::vector<string> s = ConsumeRawPtr(top.Extract());
+ EXPECT_EQ(s[0], "waldemar");
+ EXPECT_EQ(s[1], "top");
+ EXPECT_EQ(s[2], "test");
+ }
+
+ top2.push("zero");
+ EXPECT_EQ(top2.peek_bottom(), "top");
+
+ {
+ std::vector<string> s = ConsumeRawPtr(top2.Extract());
+ EXPECT_EQ(s[0], "zero");
+ EXPECT_EQ(s[1], "waldemar");
+ EXPECT_EQ(s[2], "top");
+ }
+ {
+ std::vector<string> s = ConsumeRawPtr(top3.Extract());
+ EXPECT_EQ(s[0], "waldemar");
+ EXPECT_EQ(s[1], "top");
+ EXPECT_EQ(s[2], "test");
+ }
+
+ TopN<string> top4(3);
+ // Run this test twice to check Reset():
+ for (int i = 0; i < 2; ++i) {
+ top4.push("abcd");
+ top4.push("ijkl");
+ top4.push("efgh");
+ top4.push("mnop");
+ std::vector<string> s = ConsumeRawPtr(top4.Extract());
+ EXPECT_EQ(s[0], "mnop");
+ EXPECT_EQ(s[1], "ijkl");
+ EXPECT_EQ(s[2], "efgh");
+ top4.Reset();
+ }
+}
+
+// Test that pointers aren't leaked from a TopN if we use the 2-argument version
+// of push().
+TEST(TopNTest, Ptr) {
+ LOG(INFO) << "Testing 2-argument push()";
+ TopN<string *> topn(3);
+ for (int i = 0; i < 8; ++i) {
+ string *dropped = NULL;
+ topn.push(new string(std::to_string(i)), &dropped);
+ delete dropped;
+ }
+
+ for (int i = 8; i > 0; --i) {
+ string *dropped = NULL;
+ topn.push(new string(std::to_string(i)), &dropped);
+ delete dropped;
+ }
+
+ std::vector<string *> extract = ConsumeRawPtr(topn.Extract());
+ tensorflow::gtl::STLDeleteElements(&extract);
+}
+
+struct PointeeGreater {
+ template <typename T>
+ bool operator()(const T &a, const T &b) const {
+ return *a > *b;
+ }
+};
+
+TEST(TopNTest, MoveOnly) {
+ using StrPtr = std::unique_ptr<string>;
+ TopN<StrPtr, PointeeGreater> topn(3);
+ for (int i = 0; i < 8; ++i) topn.push(StrPtr(new string(std::to_string(i))));
+ for (int i = 8; i > 0; --i) topn.push(StrPtr(new string(std::to_string(i))));
+
+ std::vector<StrPtr> extract = ConsumeRawPtr(topn.Extract());
+ EXPECT_EQ(extract.size(), 3);
+ EXPECT_EQ(*(extract[0]), "8");
+ EXPECT_EQ(*(extract[1]), "7");
+ EXPECT_EQ(*(extract[2]), "7");
+}
+
+// Test that Nondestructive extracts do not need a Reset() afterwards,
+// and that pointers aren't leaked from a TopN after calling them.
+TEST(TopNTest, Nondestructive) {
+ LOG(INFO) << "Testing Nondestructive extracts";
+ TopN<int> top4(4);
+ for (int i = 0; i < 8; ++i) {
+ top4.push(i);
+ std::vector<int> v = ConsumeRawPtr(top4.ExtractNondestructive());
+ EXPECT_EQ(std::min(i + 1, 4), v.size());
+ for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]);
+ }
+
+ TopN<int> top3(3);
+ for (int i = 0; i < 8; ++i) {
+ top3.push(i);
+ std::vector<int> v = ConsumeRawPtr(top3.ExtractUnsortedNondestructive());
+ std::sort(v.begin(), v.end(), std::greater<int>());
+ EXPECT_EQ(std::min(i + 1, 3), v.size());
+ for (size_t j = 0; j < v.size(); ++j) EXPECT_EQ(i - j, v[j]);
+ }
+}
+
+struct ForbiddenCmp {
+ bool operator()(int lhs, int rhs) const {
+ LOG(FATAL) << "ForbiddenCmp called " << lhs << " " << rhs;
+ }
+};
+
+TEST(TopNTest, ZeroLimit) {
+ TopN<int, ForbiddenCmp> top(0);
+ top.push(1);
+ top.push(2);
+
+ int dropped = -1;
+ top.push(1, &dropped);
+ top.push(2, &dropped);
+
+ std::vector<int> v;
+ top.ExtractNondestructive(&v);
+ EXPECT_EQ(0, v.size());
+}
+
+TEST(TopNTest, Iteration) {
+ TopN<int> top(4);
+ for (int i = 0; i < 8; ++i) top.push(i);
+ std::vector<int> actual(top.unsorted_begin(), top.unsorted_end());
+ // Check that we have 4,5,6,7 as the top 4 (in some order, so we sort)
+ sort(actual.begin(), actual.end());
+ EXPECT_EQ(actual.size(), 4);
+ EXPECT_EQ(actual[0], 4);
+ EXPECT_EQ(actual[1], 5);
+ EXPECT_EQ(actual[2], 6);
+ EXPECT_EQ(actual[3], 7);
+}
+} // namespace
diff --git a/tensorflow/core/lib/hash/crc32c.cc b/tensorflow/core/lib/hash/crc32c.cc
new file mode 100644
index 0000000000..3bef1cf78d
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c.cc
@@ -0,0 +1,244 @@
+// A portable implementation of crc32c, optimized to handle
+// four bytes at a time.
+
+#include "tensorflow/core/lib/hash/crc32c.h"
+
+#include <stdint.h>
+#include "tensorflow/core/lib/core/coding.h"
+
+namespace tensorflow {
+namespace crc32c {
+
+static const uint32 table0_[256] = {
+ 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c,
+ 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
+ 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c,
+ 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
+ 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc,
+ 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
+ 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512,
+ 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
+ 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad,
+ 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
+ 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf,
+ 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
+ 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f,
+ 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
+ 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f,
+ 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
+ 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e,
+ 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
+ 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e,
+ 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
+ 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de,
+ 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
+ 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4,
+ 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
+ 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b,
+ 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
+ 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5,
+ 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
+ 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975,
+ 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
+ 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905,
+ 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
+ 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8,
+ 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
+ 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8,
+ 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
+ 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78,
+ 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
+ 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6,
+ 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
+ 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69,
+ 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
+ 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351};
+static const uint32 table1_[256] = {
+ 0x00000000, 0x13a29877, 0x274530ee, 0x34e7a899, 0x4e8a61dc, 0x5d28f9ab,
+ 0x69cf5132, 0x7a6dc945, 0x9d14c3b8, 0x8eb65bcf, 0xba51f356, 0xa9f36b21,
+ 0xd39ea264, 0xc03c3a13, 0xf4db928a, 0xe7790afd, 0x3fc5f181, 0x2c6769f6,
+ 0x1880c16f, 0x0b225918, 0x714f905d, 0x62ed082a, 0x560aa0b3, 0x45a838c4,
+ 0xa2d13239, 0xb173aa4e, 0x859402d7, 0x96369aa0, 0xec5b53e5, 0xfff9cb92,
+ 0xcb1e630b, 0xd8bcfb7c, 0x7f8be302, 0x6c297b75, 0x58ced3ec, 0x4b6c4b9b,
+ 0x310182de, 0x22a31aa9, 0x1644b230, 0x05e62a47, 0xe29f20ba, 0xf13db8cd,
+ 0xc5da1054, 0xd6788823, 0xac154166, 0xbfb7d911, 0x8b507188, 0x98f2e9ff,
+ 0x404e1283, 0x53ec8af4, 0x670b226d, 0x74a9ba1a, 0x0ec4735f, 0x1d66eb28,
+ 0x298143b1, 0x3a23dbc6, 0xdd5ad13b, 0xcef8494c, 0xfa1fe1d5, 0xe9bd79a2,
+ 0x93d0b0e7, 0x80722890, 0xb4958009, 0xa737187e, 0xff17c604, 0xecb55e73,
+ 0xd852f6ea, 0xcbf06e9d, 0xb19da7d8, 0xa23f3faf, 0x96d89736, 0x857a0f41,
+ 0x620305bc, 0x71a19dcb, 0x45463552, 0x56e4ad25, 0x2c896460, 0x3f2bfc17,
+ 0x0bcc548e, 0x186eccf9, 0xc0d23785, 0xd370aff2, 0xe797076b, 0xf4359f1c,
+ 0x8e585659, 0x9dface2e, 0xa91d66b7, 0xbabffec0, 0x5dc6f43d, 0x4e646c4a,
+ 0x7a83c4d3, 0x69215ca4, 0x134c95e1, 0x00ee0d96, 0x3409a50f, 0x27ab3d78,
+ 0x809c2506, 0x933ebd71, 0xa7d915e8, 0xb47b8d9f, 0xce1644da, 0xddb4dcad,
+ 0xe9537434, 0xfaf1ec43, 0x1d88e6be, 0x0e2a7ec9, 0x3acdd650, 0x296f4e27,
+ 0x53028762, 0x40a01f15, 0x7447b78c, 0x67e52ffb, 0xbf59d487, 0xacfb4cf0,
+ 0x981ce469, 0x8bbe7c1e, 0xf1d3b55b, 0xe2712d2c, 0xd69685b5, 0xc5341dc2,
+ 0x224d173f, 0x31ef8f48, 0x050827d1, 0x16aabfa6, 0x6cc776e3, 0x7f65ee94,
+ 0x4b82460d, 0x5820de7a, 0xfbc3faf9, 0xe861628e, 0xdc86ca17, 0xcf245260,
+ 0xb5499b25, 0xa6eb0352, 0x920cabcb, 0x81ae33bc, 0x66d73941, 0x7575a136,
+ 0x419209af, 0x523091d8, 0x285d589d, 0x3bffc0ea, 0x0f186873, 0x1cbaf004,
+ 0xc4060b78, 0xd7a4930f, 0xe3433b96, 0xf0e1a3e1, 0x8a8c6aa4, 0x992ef2d3,
+ 0xadc95a4a, 0xbe6bc23d, 0x5912c8c0, 0x4ab050b7, 0x7e57f82e, 0x6df56059,
+ 0x1798a91c, 0x043a316b, 0x30dd99f2, 0x237f0185, 0x844819fb, 0x97ea818c,
+ 0xa30d2915, 0xb0afb162, 0xcac27827, 0xd960e050, 0xed8748c9, 0xfe25d0be,
+ 0x195cda43, 0x0afe4234, 0x3e19eaad, 0x2dbb72da, 0x57d6bb9f, 0x447423e8,
+ 0x70938b71, 0x63311306, 0xbb8de87a, 0xa82f700d, 0x9cc8d894, 0x8f6a40e3,
+ 0xf50789a6, 0xe6a511d1, 0xd242b948, 0xc1e0213f, 0x26992bc2, 0x353bb3b5,
+ 0x01dc1b2c, 0x127e835b, 0x68134a1e, 0x7bb1d269, 0x4f567af0, 0x5cf4e287,
+ 0x04d43cfd, 0x1776a48a, 0x23910c13, 0x30339464, 0x4a5e5d21, 0x59fcc556,
+ 0x6d1b6dcf, 0x7eb9f5b8, 0x99c0ff45, 0x8a626732, 0xbe85cfab, 0xad2757dc,
+ 0xd74a9e99, 0xc4e806ee, 0xf00fae77, 0xe3ad3600, 0x3b11cd7c, 0x28b3550b,
+ 0x1c54fd92, 0x0ff665e5, 0x759baca0, 0x663934d7, 0x52de9c4e, 0x417c0439,
+ 0xa6050ec4, 0xb5a796b3, 0x81403e2a, 0x92e2a65d, 0xe88f6f18, 0xfb2df76f,
+ 0xcfca5ff6, 0xdc68c781, 0x7b5fdfff, 0x68fd4788, 0x5c1aef11, 0x4fb87766,
+ 0x35d5be23, 0x26772654, 0x12908ecd, 0x013216ba, 0xe64b1c47, 0xf5e98430,
+ 0xc10e2ca9, 0xd2acb4de, 0xa8c17d9b, 0xbb63e5ec, 0x8f844d75, 0x9c26d502,
+ 0x449a2e7e, 0x5738b609, 0x63df1e90, 0x707d86e7, 0x0a104fa2, 0x19b2d7d5,
+ 0x2d557f4c, 0x3ef7e73b, 0xd98eedc6, 0xca2c75b1, 0xfecbdd28, 0xed69455f,
+ 0x97048c1a, 0x84a6146d, 0xb041bcf4, 0xa3e32483};
+static const uint32 table2_[256] = {
+ 0x00000000, 0xa541927e, 0x4f6f520d, 0xea2ec073, 0x9edea41a, 0x3b9f3664,
+ 0xd1b1f617, 0x74f06469, 0x38513ec5, 0x9d10acbb, 0x773e6cc8, 0xd27ffeb6,
+ 0xa68f9adf, 0x03ce08a1, 0xe9e0c8d2, 0x4ca15aac, 0x70a27d8a, 0xd5e3eff4,
+ 0x3fcd2f87, 0x9a8cbdf9, 0xee7cd990, 0x4b3d4bee, 0xa1138b9d, 0x045219e3,
+ 0x48f3434f, 0xedb2d131, 0x079c1142, 0xa2dd833c, 0xd62de755, 0x736c752b,
+ 0x9942b558, 0x3c032726, 0xe144fb14, 0x4405696a, 0xae2ba919, 0x0b6a3b67,
+ 0x7f9a5f0e, 0xdadbcd70, 0x30f50d03, 0x95b49f7d, 0xd915c5d1, 0x7c5457af,
+ 0x967a97dc, 0x333b05a2, 0x47cb61cb, 0xe28af3b5, 0x08a433c6, 0xade5a1b8,
+ 0x91e6869e, 0x34a714e0, 0xde89d493, 0x7bc846ed, 0x0f382284, 0xaa79b0fa,
+ 0x40577089, 0xe516e2f7, 0xa9b7b85b, 0x0cf62a25, 0xe6d8ea56, 0x43997828,
+ 0x37691c41, 0x92288e3f, 0x78064e4c, 0xdd47dc32, 0xc76580d9, 0x622412a7,
+ 0x880ad2d4, 0x2d4b40aa, 0x59bb24c3, 0xfcfab6bd, 0x16d476ce, 0xb395e4b0,
+ 0xff34be1c, 0x5a752c62, 0xb05bec11, 0x151a7e6f, 0x61ea1a06, 0xc4ab8878,
+ 0x2e85480b, 0x8bc4da75, 0xb7c7fd53, 0x12866f2d, 0xf8a8af5e, 0x5de93d20,
+ 0x29195949, 0x8c58cb37, 0x66760b44, 0xc337993a, 0x8f96c396, 0x2ad751e8,
+ 0xc0f9919b, 0x65b803e5, 0x1148678c, 0xb409f5f2, 0x5e273581, 0xfb66a7ff,
+ 0x26217bcd, 0x8360e9b3, 0x694e29c0, 0xcc0fbbbe, 0xb8ffdfd7, 0x1dbe4da9,
+ 0xf7908dda, 0x52d11fa4, 0x1e704508, 0xbb31d776, 0x511f1705, 0xf45e857b,
+ 0x80aee112, 0x25ef736c, 0xcfc1b31f, 0x6a802161, 0x56830647, 0xf3c29439,
+ 0x19ec544a, 0xbcadc634, 0xc85da25d, 0x6d1c3023, 0x8732f050, 0x2273622e,
+ 0x6ed23882, 0xcb93aafc, 0x21bd6a8f, 0x84fcf8f1, 0xf00c9c98, 0x554d0ee6,
+ 0xbf63ce95, 0x1a225ceb, 0x8b277743, 0x2e66e53d, 0xc448254e, 0x6109b730,
+ 0x15f9d359, 0xb0b84127, 0x5a968154, 0xffd7132a, 0xb3764986, 0x1637dbf8,
+ 0xfc191b8b, 0x595889f5, 0x2da8ed9c, 0x88e97fe2, 0x62c7bf91, 0xc7862def,
+ 0xfb850ac9, 0x5ec498b7, 0xb4ea58c4, 0x11abcaba, 0x655baed3, 0xc01a3cad,
+ 0x2a34fcde, 0x8f756ea0, 0xc3d4340c, 0x6695a672, 0x8cbb6601, 0x29faf47f,
+ 0x5d0a9016, 0xf84b0268, 0x1265c21b, 0xb7245065, 0x6a638c57, 0xcf221e29,
+ 0x250cde5a, 0x804d4c24, 0xf4bd284d, 0x51fcba33, 0xbbd27a40, 0x1e93e83e,
+ 0x5232b292, 0xf77320ec, 0x1d5de09f, 0xb81c72e1, 0xccec1688, 0x69ad84f6,
+ 0x83834485, 0x26c2d6fb, 0x1ac1f1dd, 0xbf8063a3, 0x55aea3d0, 0xf0ef31ae,
+ 0x841f55c7, 0x215ec7b9, 0xcb7007ca, 0x6e3195b4, 0x2290cf18, 0x87d15d66,
+ 0x6dff9d15, 0xc8be0f6b, 0xbc4e6b02, 0x190ff97c, 0xf321390f, 0x5660ab71,
+ 0x4c42f79a, 0xe90365e4, 0x032da597, 0xa66c37e9, 0xd29c5380, 0x77ddc1fe,
+ 0x9df3018d, 0x38b293f3, 0x7413c95f, 0xd1525b21, 0x3b7c9b52, 0x9e3d092c,
+ 0xeacd6d45, 0x4f8cff3b, 0xa5a23f48, 0x00e3ad36, 0x3ce08a10, 0x99a1186e,
+ 0x738fd81d, 0xd6ce4a63, 0xa23e2e0a, 0x077fbc74, 0xed517c07, 0x4810ee79,
+ 0x04b1b4d5, 0xa1f026ab, 0x4bdee6d8, 0xee9f74a6, 0x9a6f10cf, 0x3f2e82b1,
+ 0xd50042c2, 0x7041d0bc, 0xad060c8e, 0x08479ef0, 0xe2695e83, 0x4728ccfd,
+ 0x33d8a894, 0x96993aea, 0x7cb7fa99, 0xd9f668e7, 0x9557324b, 0x3016a035,
+ 0xda386046, 0x7f79f238, 0x0b899651, 0xaec8042f, 0x44e6c45c, 0xe1a75622,
+ 0xdda47104, 0x78e5e37a, 0x92cb2309, 0x378ab177, 0x437ad51e, 0xe63b4760,
+ 0x0c158713, 0xa954156d, 0xe5f54fc1, 0x40b4ddbf, 0xaa9a1dcc, 0x0fdb8fb2,
+ 0x7b2bebdb, 0xde6a79a5, 0x3444b9d6, 0x91052ba8};
+static const uint32 table3_[256] = {
+ 0x00000000, 0xdd45aab8, 0xbf672381, 0x62228939, 0x7b2231f3, 0xa6679b4b,
+ 0xc4451272, 0x1900b8ca, 0xf64463e6, 0x2b01c95e, 0x49234067, 0x9466eadf,
+ 0x8d665215, 0x5023f8ad, 0x32017194, 0xef44db2c, 0xe964b13d, 0x34211b85,
+ 0x560392bc, 0x8b463804, 0x924680ce, 0x4f032a76, 0x2d21a34f, 0xf06409f7,
+ 0x1f20d2db, 0xc2657863, 0xa047f15a, 0x7d025be2, 0x6402e328, 0xb9474990,
+ 0xdb65c0a9, 0x06206a11, 0xd725148b, 0x0a60be33, 0x6842370a, 0xb5079db2,
+ 0xac072578, 0x71428fc0, 0x136006f9, 0xce25ac41, 0x2161776d, 0xfc24ddd5,
+ 0x9e0654ec, 0x4343fe54, 0x5a43469e, 0x8706ec26, 0xe524651f, 0x3861cfa7,
+ 0x3e41a5b6, 0xe3040f0e, 0x81268637, 0x5c632c8f, 0x45639445, 0x98263efd,
+ 0xfa04b7c4, 0x27411d7c, 0xc805c650, 0x15406ce8, 0x7762e5d1, 0xaa274f69,
+ 0xb327f7a3, 0x6e625d1b, 0x0c40d422, 0xd1057e9a, 0xaba65fe7, 0x76e3f55f,
+ 0x14c17c66, 0xc984d6de, 0xd0846e14, 0x0dc1c4ac, 0x6fe34d95, 0xb2a6e72d,
+ 0x5de23c01, 0x80a796b9, 0xe2851f80, 0x3fc0b538, 0x26c00df2, 0xfb85a74a,
+ 0x99a72e73, 0x44e284cb, 0x42c2eeda, 0x9f874462, 0xfda5cd5b, 0x20e067e3,
+ 0x39e0df29, 0xe4a57591, 0x8687fca8, 0x5bc25610, 0xb4868d3c, 0x69c32784,
+ 0x0be1aebd, 0xd6a40405, 0xcfa4bccf, 0x12e11677, 0x70c39f4e, 0xad8635f6,
+ 0x7c834b6c, 0xa1c6e1d4, 0xc3e468ed, 0x1ea1c255, 0x07a17a9f, 0xdae4d027,
+ 0xb8c6591e, 0x6583f3a6, 0x8ac7288a, 0x57828232, 0x35a00b0b, 0xe8e5a1b3,
+ 0xf1e51979, 0x2ca0b3c1, 0x4e823af8, 0x93c79040, 0x95e7fa51, 0x48a250e9,
+ 0x2a80d9d0, 0xf7c57368, 0xeec5cba2, 0x3380611a, 0x51a2e823, 0x8ce7429b,
+ 0x63a399b7, 0xbee6330f, 0xdcc4ba36, 0x0181108e, 0x1881a844, 0xc5c402fc,
+ 0xa7e68bc5, 0x7aa3217d, 0x52a0c93f, 0x8fe56387, 0xedc7eabe, 0x30824006,
+ 0x2982f8cc, 0xf4c75274, 0x96e5db4d, 0x4ba071f5, 0xa4e4aad9, 0x79a10061,
+ 0x1b838958, 0xc6c623e0, 0xdfc69b2a, 0x02833192, 0x60a1b8ab, 0xbde41213,
+ 0xbbc47802, 0x6681d2ba, 0x04a35b83, 0xd9e6f13b, 0xc0e649f1, 0x1da3e349,
+ 0x7f816a70, 0xa2c4c0c8, 0x4d801be4, 0x90c5b15c, 0xf2e73865, 0x2fa292dd,
+ 0x36a22a17, 0xebe780af, 0x89c50996, 0x5480a32e, 0x8585ddb4, 0x58c0770c,
+ 0x3ae2fe35, 0xe7a7548d, 0xfea7ec47, 0x23e246ff, 0x41c0cfc6, 0x9c85657e,
+ 0x73c1be52, 0xae8414ea, 0xcca69dd3, 0x11e3376b, 0x08e38fa1, 0xd5a62519,
+ 0xb784ac20, 0x6ac10698, 0x6ce16c89, 0xb1a4c631, 0xd3864f08, 0x0ec3e5b0,
+ 0x17c35d7a, 0xca86f7c2, 0xa8a47efb, 0x75e1d443, 0x9aa50f6f, 0x47e0a5d7,
+ 0x25c22cee, 0xf8878656, 0xe1873e9c, 0x3cc29424, 0x5ee01d1d, 0x83a5b7a5,
+ 0xf90696d8, 0x24433c60, 0x4661b559, 0x9b241fe1, 0x8224a72b, 0x5f610d93,
+ 0x3d4384aa, 0xe0062e12, 0x0f42f53e, 0xd2075f86, 0xb025d6bf, 0x6d607c07,
+ 0x7460c4cd, 0xa9256e75, 0xcb07e74c, 0x16424df4, 0x106227e5, 0xcd278d5d,
+ 0xaf050464, 0x7240aedc, 0x6b401616, 0xb605bcae, 0xd4273597, 0x09629f2f,
+ 0xe6264403, 0x3b63eebb, 0x59416782, 0x8404cd3a, 0x9d0475f0, 0x4041df48,
+ 0x22635671, 0xff26fcc9, 0x2e238253, 0xf36628eb, 0x9144a1d2, 0x4c010b6a,
+ 0x5501b3a0, 0x88441918, 0xea669021, 0x37233a99, 0xd867e1b5, 0x05224b0d,
+ 0x6700c234, 0xba45688c, 0xa345d046, 0x7e007afe, 0x1c22f3c7, 0xc167597f,
+ 0xc747336e, 0x1a0299d6, 0x782010ef, 0xa565ba57, 0xbc65029d, 0x6120a825,
+ 0x0302211c, 0xde478ba4, 0x31035088, 0xec46fa30, 0x8e647309, 0x5321d9b1,
+ 0x4a21617b, 0x9764cbc3, 0xf54642fa, 0x2803e842};
+
+// Used to fetch a naturally-aligned 32-bit word in little endian byte-order
+static inline uint32_t LE_LOAD32(const uint8_t *p) {
+ return core::DecodeFixed32(reinterpret_cast<const char *>(p));
+}
+
+uint32 Extend(uint32 crc, const char *buf, size_t size) {
+ const uint8 *p = reinterpret_cast<const uint8 *>(buf);
+ const uint8 *e = p + size;
+ uint32 l = crc ^ 0xffffffffu;
+
+#define STEP1 \
+ do { \
+ int c = (l & 0xff) ^ *p++; \
+ l = table0_[c] ^ (l >> 8); \
+ } while (0)
+
+#define STEP4 \
+ do { \
+ uint32 c = l ^ LE_LOAD32(p); \
+ p += 4; \
+ l = table3_[c & 0xff] ^ table2_[(c >> 8) & 0xff] ^ \
+ table1_[(c >> 16) & 0xff] ^ table0_[c >> 24]; \
+ } while (0)
+
+ // Point x at first 4-byte aligned byte in string. This might be
+ // just past the end of the string.
+ const uintptr_t pval = reinterpret_cast<uintptr_t>(p);
+ const uint8 *x = reinterpret_cast<const uint8 *>(((pval + 3) >> 2) << 2);
+ if (x <= e) {
+ // Process bytes until finished or p is 4-byte aligned
+ while (p != x) {
+ STEP1;
+ }
+ }
+ // Process bytes 16 at a time
+ while ((e - p) >= 16) {
+ STEP4;
+ STEP4;
+ STEP4;
+ STEP4;
+ }
+ // Process bytes 4 at a time
+ while ((e - p) >= 4) {
+ STEP4;
+ }
+ // Process the last few bytes
+ while (p != e) {
+ STEP1;
+ }
+#undef STEP4
+#undef STEP1
+ return l ^ 0xffffffffu;
+}
+
+} // namespace crc32c
+} // namespace tensorflow
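
The Extend() routine above is what makes streaming checksums cheap: feeding data in chunks yields the same value as checksumming the whole buffer at once (the same property the CRC.Extend test later in this change exercises). A minimal sketch of that use, assuming only the crc32c.h interface added here:

#include <cassert>
#include "tensorflow/core/lib/hash/crc32c.h"

int main() {
  const char data[] = "hello world";
  const size_t n = sizeof(data) - 1;  // 11 bytes, excluding the trailing NUL

  // One-shot checksum of the whole buffer.
  tensorflow::uint32 whole = tensorflow::crc32c::Value(data, n);

  // Streaming: start from the crc of the first chunk and Extend() it with
  // the rest; the result matches the one-shot value.
  tensorflow::uint32 running = tensorflow::crc32c::Value(data, 6);  // "hello "
  running = tensorflow::crc32c::Extend(running, data + 6, n - 6);   // "world"

  assert(whole == running);
  return 0;
}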
diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h
new file mode 100644
index 0000000000..f728b6f5e7
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c.h
@@ -0,0 +1,39 @@
+#ifndef TENSORFLOW_LIB_HASH_CRC32C_H_
+#define TENSORFLOW_LIB_HASH_CRC32C_H_
+
+#include <stddef.h>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace crc32c {
+
+// Return the crc32c of concat(A, data[0,n-1]) where init_crc is the
+// crc32c of some string A. Extend() is often used to maintain the
+// crc32c of a stream of data.
+extern uint32 Extend(uint32 init_crc, const char* data, size_t n);
+
+// Return the crc32c of data[0,n-1]
+inline uint32 Value(const char* data, size_t n) { return Extend(0, data, n); }
+
+static const uint32 kMaskDelta = 0xa282ead8ul;
+
+// Return a masked representation of crc.
+//
+// Motivation: it is problematic to compute the CRC of a string that
+// contains embedded CRCs. Therefore we recommend that CRCs stored
+// somewhere (e.g., in files) should be masked before being stored.
+inline uint32 Mask(uint32 crc) {
+ // Rotate right by 15 bits and add a constant.
+ return ((crc >> 15) | (crc << 17)) + kMaskDelta;
+}
+
+// Return the crc whose masked representation is masked_crc.
+inline uint32 Unmask(uint32 masked_crc) {
+ uint32 rot = masked_crc - kMaskDelta;
+ return ((rot >> 17) | (rot << 15));
+}
+
+} // namespace crc32c
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_HASH_CRC32C_H_
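
The masking helpers above exist so a CRC can itself be embedded in data that may later be checksummed, e.g. a record framed as payload plus stored crc. A small sketch of the intended store/verify round trip; the record layout and helper names here are illustrative, not part of this API:

#include <string>
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/hash/crc32c.h"

namespace {

// Append the payload followed by a masked crc32c of the payload.
void AppendRecord(const std::string& payload, std::string* out) {
  out->append(payload);
  tensorflow::uint32 crc =
      tensorflow::crc32c::Value(payload.data(), payload.size());
  tensorflow::core::PutFixed32(out, tensorflow::crc32c::Mask(crc));
}

// Verify a record produced by AppendRecord: recompute the crc of the payload
// and compare it against the unmasked stored value.
bool VerifyRecord(const std::string& record) {
  if (record.size() < sizeof(tensorflow::uint32)) return false;
  const size_t payload_size = record.size() - sizeof(tensorflow::uint32);
  tensorflow::uint32 stored =
      tensorflow::core::DecodeFixed32(record.data() + payload_size);
  tensorflow::uint32 expected =
      tensorflow::crc32c::Value(record.data(), payload_size);
  return tensorflow::crc32c::Unmask(stored) == expected;
}

}  // namespace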
diff --git a/tensorflow/core/lib/hash/crc32c_test.cc b/tensorflow/core/lib/hash/crc32c_test.cc
new file mode 100644
index 0000000000..54aced3186
--- /dev/null
+++ b/tensorflow/core/lib/hash/crc32c_test.cc
@@ -0,0 +1,51 @@
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace crc32c {
+
+TEST(CRC, StandardResults) {
+ // From rfc3720 section B.4.
+ char buf[32];
+
+ memset(buf, 0, sizeof(buf));
+ ASSERT_EQ(0x8a9136aa, Value(buf, sizeof(buf)));
+
+ memset(buf, 0xff, sizeof(buf));
+ ASSERT_EQ(0x62a8ab43, Value(buf, sizeof(buf)));
+
+ for (int i = 0; i < 32; i++) {
+ buf[i] = i;
+ }
+ ASSERT_EQ(0x46dd794e, Value(buf, sizeof(buf)));
+
+ for (int i = 0; i < 32; i++) {
+ buf[i] = 31 - i;
+ }
+ ASSERT_EQ(0x113fdb5c, Value(buf, sizeof(buf)));
+
+ unsigned char data[48] = {
+ 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
+ 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+ ASSERT_EQ(0xd9963a56, Value(reinterpret_cast<char*>(data), sizeof(data)));
+}
+
+TEST(CRC, Values) { ASSERT_NE(Value("a", 1), Value("foo", 3)); }
+
+TEST(CRC, Extend) {
+ ASSERT_EQ(Value("hello world", 11), Extend(Value("hello ", 6), "world", 5));
+}
+
+TEST(CRC, Mask) {
+ uint32 crc = Value("foo", 3);
+ ASSERT_NE(crc, Mask(crc));
+ ASSERT_NE(crc, Mask(Mask(crc)));
+ ASSERT_EQ(crc, Unmask(Mask(crc)));
+ ASSERT_EQ(crc, Unmask(Unmask(Mask(Mask(crc)))));
+}
+
+} // namespace crc32c
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/hash.cc b/tensorflow/core/lib/hash/hash.cc
new file mode 100644
index 0000000000..075d252412
--- /dev/null
+++ b/tensorflow/core/lib/hash/hash.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/hash/hash.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/raw_coding.h"
+
+#include <string.h>
+
+namespace tensorflow {
+
+// 0xff is in case char is signed.
+static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
+static inline uint64 ByteAs64(char c) { return static_cast<uint64>(c) & 0xff; }
+
+uint32 Hash32(const char* data, size_t n, uint32 seed) {
+ // 'm' and 'r' are mixing constants generated offline.
+ // They're not really 'magic', they just happen to work well.
+
+ const uint32 m = 0x5bd1e995;
+ const int r = 24;
+
+ // Initialize the hash to a 'random' value
+ uint32 h = seed ^ n;
+
+ // Mix 4 bytes at a time into the hash
+ while (n >= 4) {
+ uint32 k = core::DecodeFixed32(data);
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h *= m;
+ h ^= k;
+
+ data += 4;
+ n -= 4;
+ }
+
+ // Handle the last few bytes of the input array
+
+ switch (n) {
+ case 3:
+ h ^= ByteAs32(data[2]) << 16;
+ TF_FALLTHROUGH_INTENDED;
+ case 2:
+ h ^= ByteAs32(data[1]) << 8;
+ TF_FALLTHROUGH_INTENDED;
+ case 1:
+ h ^= ByteAs32(data[0]);
+ h *= m;
+ }
+
+ // Do a few final mixes of the hash to ensure the last few
+ // bytes are well-incorporated.
+
+ h ^= h >> 13;
+ h *= m;
+ h ^= h >> 15;
+
+ return h;
+}
+
+uint64 Hash64(const char* data, size_t n, uint64 seed) {
+ const uint64 m = 0xc6a4a7935bd1e995;
+ const int r = 47;
+
+ uint64 h = seed ^ (n * m);
+
+ while (n >= 8) {
+ uint64 k = core::DecodeFixed64(data);
+ data += 8;
+ n -= 8;
+
+ k *= m;
+ k ^= k >> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ switch (n) {
+ case 7:
+ h ^= ByteAs64(data[6]) << 48;
+ TF_FALLTHROUGH_INTENDED;
+ case 6:
+ h ^= ByteAs64(data[5]) << 40;
+ TF_FALLTHROUGH_INTENDED;
+ case 5:
+ h ^= ByteAs64(data[4]) << 32;
+ TF_FALLTHROUGH_INTENDED;
+ case 4:
+ h ^= ByteAs64(data[3]) << 24;
+ TF_FALLTHROUGH_INTENDED;
+ case 3:
+ h ^= ByteAs64(data[2]) << 16;
+ TF_FALLTHROUGH_INTENDED;
+ case 2:
+ h ^= ByteAs64(data[1]) << 8;
+ TF_FALLTHROUGH_INTENDED;
+ case 1:
+ h ^= ByteAs64(data[0]);
+ h *= m;
+ }
+
+ h ^= h >> r;
+ h *= m;
+ h ^= h >> r;
+
+ return h;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
new file mode 100644
index 0000000000..af56218fed
--- /dev/null
+++ b/tensorflow/core/lib/hash/hash.h
@@ -0,0 +1,28 @@
+// Simple hash functions used for internal data structures
+
+#ifndef TENSORFLOW_LIB_HASH_HASH_H_
+#define TENSORFLOW_LIB_HASH_HASH_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <string>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+extern uint32 Hash32(const char* data, size_t n, uint32 seed);
+extern uint64 Hash64(const char* data, size_t n, uint64 seed);
+
+inline uint64 Hash64(const char* data, size_t n) {
+ return Hash64(data, n, 0xDECAFCAFFE);
+}
+
+inline uint64 Hash64(const string& str) {
+ return Hash64(str.data(), str.size());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_HASH_HASH_H_
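
Hash64 with the default seed is the form most callers will want; a typical use is mapping string keys onto a fixed number of shards or hash-table buckets. A minimal sketch assuming only the declarations above (the shard count and key set are illustrative):

#include <string>
#include <vector>
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  const int kNumShards = 16;
  std::vector<std::string> keys = {"alpha", "beta", "gamma"};
  for (const std::string& key : keys) {
    // Hash64(const string&) uses the fixed default seed, so the shard
    // assignment is stable across processes.
    tensorflow::uint64 h = tensorflow::Hash64(key);
    int shard = static_cast<int>(h % kNumShards);
    LOG(INFO) << key << " -> shard " << shard;
  }
  return 0;
}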
diff --git a/tensorflow/core/lib/hash/hash_test.cc b/tensorflow/core/lib/hash/hash_test.cc
new file mode 100644
index 0000000000..9d3b970f3b
--- /dev/null
+++ b/tensorflow/core/lib/hash/hash_test.cc
@@ -0,0 +1,64 @@
+#include <vector>
+
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Hash, SignedUnsignedIssue) {
+ const unsigned char d1[1] = {0x62};
+ const unsigned char d2[2] = {0xc3, 0x97};
+ const unsigned char d3[3] = {0xe2, 0x99, 0xa5};
+ const unsigned char d4[4] = {0xe1, 0x80, 0xb9, 0x32};
+ const unsigned char d5[48] = {
+ 0x01, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00,
+ 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x18, 0x28, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ };
+
+ struct Case {
+ uint32 hash32;
+ uint64 hash64;
+ const unsigned char* data;
+ size_t size;
+ uint32 seed;
+ };
+
+ for (Case c : std::vector<Case>{
+ {0x471a8188u, 0x4c61ea3eeda4cb87ull, nullptr, 0, 0xbc9f1d34},
+ {0xd615eba5u, 0x091309f7ef916c8aull, d1, sizeof(d1), 0xbc9f1d34},
+ {0x0c3cccdau, 0xa815bcdf1d1af01cull, d2, sizeof(d2), 0xbc9f1d34},
+ {0x3ba37e0eu, 0x02167564e4d06430ull, d3, sizeof(d3), 0xbc9f1d34},
+ {0x16174eb3u, 0x8f7ed82ffc21071full, d4, sizeof(d4), 0xbc9f1d34},
+ {0x98b1926cu, 0xce196580c97aff1eull, d5, sizeof(d5), 0x12345678},
+ }) {
+ EXPECT_EQ(c.hash32,
+ Hash32(reinterpret_cast<const char*>(c.data), c.size, c.seed));
+ EXPECT_EQ(c.hash64,
+ Hash64(reinterpret_cast<const char*>(c.data), c.size, c.seed));
+
+ // Check hashes with inputs aligned differently.
+ for (int align = 1; align <= 7; align++) {
+ std::string input(align, 'x');
+ input.append(reinterpret_cast<const char*>(c.data), c.size);
+ EXPECT_EQ(c.hash32, Hash32(&input[align], c.size, c.seed));
+ EXPECT_EQ(c.hash64, Hash64(&input[align], c.size, c.seed));
+ }
+ }
+}
+
+static void BM_Hash32(int iters, int len) {
+ std::string input(len, 'x');
+ uint32 h = 0;
+ for (int i = 0; i < iters; i++) {
+ h = Hash32(input.data(), len, 1);
+ }
+ testing::BytesProcessed(static_cast<int64>(iters) * len);
+ VLOG(1) << h;
+}
+BENCHMARK(BM_Hash32)->Range(1, 1024);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/histogram/histogram.cc b/tensorflow/core/lib/histogram/histogram.cc
new file mode 100644
index 0000000000..4c29d687b7
--- /dev/null
+++ b/tensorflow/core/lib/histogram/histogram.cc
@@ -0,0 +1,247 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include <float.h>
+#include <math.h>
+#include "tensorflow/core/framework/summary.pb.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+namespace tensorflow {
+namespace histogram {
+
+static std::vector<double>* InitDefaultBucketsInner() {
+ std::vector<double> buckets;
+ std::vector<double> neg_buckets;
+ // Make buckets whose range grows by 10% starting at 1.0e-12 up to 1.0e20
+ double v = 1.0e-12;
+ while (v < 1.0e20) {
+ buckets.push_back(v);
+ neg_buckets.push_back(-v);
+ v *= 1.1;
+ }
+ buckets.push_back(DBL_MAX);
+ neg_buckets.push_back(-DBL_MAX);
+ std::reverse(neg_buckets.begin(), neg_buckets.end());
+ std::vector<double>* result = new std::vector<double>;
+ result->insert(result->end(), neg_buckets.begin(), neg_buckets.end());
+ result->push_back(0.0);
+ result->insert(result->end(), buckets.begin(), buckets.end());
+ return result;
+}
+
+static gtl::ArraySlice<double> InitDefaultBuckets() {
+ static std::vector<double>* default_bucket_limits = InitDefaultBucketsInner();
+ return *default_bucket_limits;
+}
+
+Histogram::Histogram() : bucket_limits_(InitDefaultBuckets()) { Clear(); }
+
+// Create a histogram with a custom set of bucket limits,
+// specified in "custom_buckets[0..custom_buckets.size()-1]"
+Histogram::Histogram(gtl::ArraySlice<double> custom_bucket_limits)
+ : custom_bucket_limits_(custom_bucket_limits.begin(),
+ custom_bucket_limits.end()),
+ bucket_limits_(custom_bucket_limits_) {
+#ifndef NDEBUG
+ DCHECK_GT(bucket_limits_.size(), 0);
+ // Verify that the bucket boundaries are strictly increasing
+ for (size_t i = 1; i < bucket_limits_.size(); i++) {
+ DCHECK_GT(bucket_limits_[i], bucket_limits_[i - 1]);
+ }
+#endif
+ Clear();
+}
+
+bool Histogram::DecodeFromProto(const HistogramProto& proto) {
+ if ((proto.bucket_size() != proto.bucket_limit_size()) ||
+ (proto.bucket_size() == 0)) {
+ return false;
+ }
+ min_ = proto.min();
+ max_ = proto.max();
+ num_ = proto.num();
+ sum_ = proto.sum();
+ sum_squares_ = proto.sum_squares();
+ custom_bucket_limits_.clear();
+ custom_bucket_limits_.insert(custom_bucket_limits_.end(),
+ proto.bucket_limit().begin(),
+ proto.bucket_limit().end());
+ bucket_limits_ = custom_bucket_limits_;
+ buckets_.clear();
+ buckets_.insert(buckets_.end(), proto.bucket().begin(), proto.bucket().end());
+ return true;
+}
+
+void Histogram::Clear() {
+ min_ = bucket_limits_[bucket_limits_.size() - 1];
+ max_ = -DBL_MAX;
+ num_ = 0;
+ sum_ = 0;
+ sum_squares_ = 0;
+ buckets_.resize(bucket_limits_.size());
+ for (size_t i = 0; i < bucket_limits_.size(); i++) {
+ buckets_[i] = 0;
+ }
+}
+
+void Histogram::Add(double value) {
+ int b =
+ std::upper_bound(bucket_limits_.begin(), bucket_limits_.end(), value) -
+ bucket_limits_.begin();
+
+ buckets_[b] += 1.0;
+ if (min_ > value) min_ = value;
+ if (max_ < value) max_ = value;
+ num_++;
+ sum_ += value;
+ sum_squares_ += (value * value);
+}
+
+double Histogram::Median() const { return Percentile(50.0); }
+
+double Histogram::Percentile(double p) const {
+ if (num_ == 0.0) return 0.0;
+ double threshold = num_ * (p / 100.0);
+ double sum = 0;
+ for (size_t b = 0; b < buckets_.size(); b++) {
+ sum += buckets_[b];
+ if (sum >= threshold) {
+ // Scale linearly within this bucket
+ double left_point = (b == 0) ? min_ : bucket_limits_[b - 1];
+ double right_point = bucket_limits_[b];
+ double left_sum = sum - buckets_[b];
+ double right_sum = sum;
+ double pos = (threshold - left_sum) / (right_sum - left_sum);
+ double r = left_point + (right_point - left_point) * pos;
+ if (r < min_) r = min_;
+ if (r > max_) r = max_;
+ return r;
+ }
+ }
+ return max_;
+}
+
+double Histogram::Average() const {
+ if (num_ == 0.0) return 0;
+ return sum_ / num_;
+}
+
+double Histogram::StandardDeviation() const {
+ if (num_ == 0.0) return 0;
+ double variance = (sum_squares_ * num_ - sum_ * sum_) / (num_ * num_);
+ return sqrt(variance);
+}
+
+std::string Histogram::ToString() const {
+ std::string r;
+ char buf[200];
+ snprintf(buf, sizeof(buf), "Count: %.0f Average: %.4f StdDev: %.2f\n", num_,
+ Average(), StandardDeviation());
+ r.append(buf);
+ snprintf(buf, sizeof(buf), "Min: %.4f Median: %.4f Max: %.4f\n",
+ (num_ == 0.0 ? 0.0 : min_), Median(), max_);
+ r.append(buf);
+ r.append("------------------------------------------------------\n");
+ const double mult = num_ > 0 ? 100.0 / num_ : 0.0;
+ double sum = 0;
+ for (size_t b = 0; b < buckets_.size(); b++) {
+ if (buckets_[b] <= 0.0) continue;
+ sum += buckets_[b];
+ snprintf(buf, sizeof(buf), "[ %10.2g, %10.2g ) %7.0f %7.3f%% %7.3f%% ",
+ ((b == 0) ? -DBL_MAX : bucket_limits_[b - 1]), // left
+ bucket_limits_[b], // right
+ buckets_[b], // count
+ mult * buckets_[b], // percentage
+ mult * sum); // cum percentage
+ r.append(buf);
+
+ // Add hash marks based on percentage; 20 marks for 100%.
+ int marks = static_cast<int>(20 * (buckets_[b] / num_) + 0.5);
+ r.append(marks, '#');
+ r.push_back('\n');
+ }
+ return r;
+}
+
+void Histogram::EncodeToProto(HistogramProto* proto,
+ bool preserve_zero_buckets) const {
+ proto->Clear();
+ proto->set_min(min_);
+ proto->set_max(max_);
+ proto->set_num(num_);
+ proto->set_sum(sum_);
+ proto->set_sum_squares(sum_squares_);
+ for (size_t i = 0; i < buckets_.size();) {
+ double end = bucket_limits_[i];
+ double count = buckets_[i];
+ i++;
+ if (!preserve_zero_buckets && count <= 0.0) {
+ // Find run of empty buckets and collapse them into one
+ while (i < buckets_.size() && buckets_[i] <= 0.0) {
+ end = bucket_limits_[i];
+ count = buckets_[i];
+ i++;
+ }
+ }
+ proto->add_bucket_limit(end);
+ proto->add_bucket(count);
+ }
+ if (proto->bucket_size() == 0) {
+ // It's easier when we restore if we always have at least one bucket entry
+ proto->add_bucket_limit(DBL_MAX);
+ proto->add_bucket(0.0);
+ }
+}
+
+// ThreadSafeHistogram implementation.
+bool ThreadSafeHistogram::DecodeFromProto(const HistogramProto& proto) {
+ mutex_lock l(mu_);
+ return histogram_.DecodeFromProto(proto);
+}
+
+void ThreadSafeHistogram::Clear() {
+ mutex_lock l(mu_);
+ histogram_.Clear();
+}
+
+void ThreadSafeHistogram::Add(double value) {
+ mutex_lock l(mu_);
+ histogram_.Add(value);
+}
+
+void ThreadSafeHistogram::EncodeToProto(HistogramProto* proto,
+ bool preserve_zero_buckets) const {
+ mutex_lock l(mu_);
+ histogram_.EncodeToProto(proto, preserve_zero_buckets);
+}
+
+double ThreadSafeHistogram::Median() const {
+ mutex_lock l(mu_);
+ return histogram_.Median();
+}
+
+double ThreadSafeHistogram::Percentile(double p) const {
+ mutex_lock l(mu_);
+ return histogram_.Percentile(p);
+}
+
+double ThreadSafeHistogram::Average() const {
+ mutex_lock l(mu_);
+ return histogram_.Average();
+}
+
+double ThreadSafeHistogram::StandardDeviation() const {
+ mutex_lock l(mu_);
+ return histogram_.StandardDeviation();
+}
+
+std::string ThreadSafeHistogram::ToString() const {
+ mutex_lock l(mu_);
+ return histogram_.ToString();
+}
+
+} // namespace histogram
+} // namespace tensorflow
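
Percentile() above finds the bucket where the cumulative count crosses the requested threshold and then interpolates linearly inside that bucket, clamping to the observed min/max. The standalone sketch below reproduces just that interpolation step with all inputs passed explicitly; it illustrates the formula and is not part of the Histogram API:

#include <algorithm>

// Linear interpolation inside the bucket [left_point, right_point) whose
// cumulative counts straddle `threshold` (threshold = num * p / 100).
// left_sum / right_sum are the cumulative counts at the bucket's edges.
double InterpolateWithinBucket(double threshold, double left_point,
                               double right_point, double left_sum,
                               double right_sum, double min_seen,
                               double max_seen) {
  double pos = (threshold - left_sum) / (right_sum - left_sum);
  double r = left_point + (right_point - left_point) * pos;
  // Clamp to the range of values actually observed.
  return std::min(std::max(r, min_seen), max_seen);
}

// Example: in bucket [0, 10) the cumulative count rises from 2 to 3; with
// threshold 2.5 the estimate is 5.0, unless clamped by min/max:
//   InterpolateWithinBucket(2.5, 0.0, 10.0, 2.0, 3.0, -2.0, 7.0) == 5.0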
diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h
new file mode 100644
index 0000000000..9b655f3acb
--- /dev/null
+++ b/tensorflow/core/lib/histogram/histogram.h
@@ -0,0 +1,119 @@
+#ifndef TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+#define TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+
+#include <string>
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class HistogramProto;
+
+namespace histogram {
+
+class Histogram {
+ public:
+ // Create a histogram with a default set of bucket boundaries.
+ // Buckets near zero cover very small ranges (e.g. 10^-12), and each
+ // bucket range grows by ~10% as we head away from zero. The
+ // buckets cover the range from -DBL_MAX to DBL_MAX.
+ Histogram();
+
+ // Create a histogram with a custom set of bucket boundaries,
+ // specified in "custom_bucket_limits[0..custom_bucket_limits.size()-1]"
+ // REQUIRES: custom_bucket_limits[i] values are monotonically increasing.
+ // REQUIRES: custom_bucket_limits is not empty()
+ explicit Histogram(gtl::ArraySlice<double> custom_bucket_limits);
+
+ // Restore the state of a histogram that was previously encoded
+ // via Histogram::EncodeToProto. Note that only the bucket boundaries
+ // generated by EncodeToProto will be restored.
+ bool DecodeFromProto(const HistogramProto& proto);
+
+ ~Histogram() {}
+
+ void Clear();
+ void Add(double value);
+
+ // Save the current state of the histogram to "*proto". If
+ // "preserve_zero_buckets" is false, only non-zero bucket values and
+ // ranges are saved, and the bucket boundaries of zero-valued buckets
+ // are lost.
+ void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const;
+
+ // Return the median of the values in the histogram
+ double Median() const;
+
+ // Return the "p"th percentile [0.0..100.0] of the values in the
+ // distribution
+ double Percentile(double p) const;
+
+ // Return the average value of the distribution
+ double Average() const;
+
+ // Return the standard deviation of values in the distribution
+ double StandardDeviation() const;
+
+ // Returns a multi-line human-readable string representing the histogram
+ // contents. Example output:
+ // Count: 4 Average: 251.7475 StdDev: 432.02
+ // Min: -3.0000 Median: 5.0000 Max: 1000.0000
+ // ------------------------------------------------------
+ // [ -5, 0 ) 1 25.000% 25.000% #####
+ // [ 0, 5 ) 1 25.000% 50.000% #####
+ // [ 5, 10 ) 1 25.000% 75.000% #####
+ // [ 1000, 10000 ) 1 25.000% 100.000% #####
+ std::string ToString() const;
+
+ private:
+ double min_;
+ double max_;
+ double num_;
+ double sum_;
+ double sum_squares_;
+
+ std::vector<double> custom_bucket_limits_;
+ gtl::ArraySlice<double> bucket_limits_;
+ std::vector<double> buckets_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Histogram);
+};
+
+// Wrapper around a Histogram object that is thread safe.
+//
+// All methods hold a lock while delegating to a Histogram object owned by the
+// ThreadSafeHistogram instance.
+//
+// See Histogram for documentation of the methods.
+class ThreadSafeHistogram {
+ public:
+ ThreadSafeHistogram() {}
+ explicit ThreadSafeHistogram(gtl::ArraySlice<double> custom_bucket_limits)
+ : histogram_(custom_bucket_limits) {}
+ bool DecodeFromProto(const HistogramProto& proto);
+
+ ~ThreadSafeHistogram() {}
+
+ void Clear();
+
+ // TODO(mdevin): It might be a good idea to provide an AddN(<many values>)
+ // method to avoid grabbing/releasing the lock when adding many values.
+ void Add(double value);
+
+ void EncodeToProto(HistogramProto* proto, bool preserve_zero_buckets) const;
+ double Median() const;
+ double Percentile(double p) const;
+ double Average() const;
+ double StandardDeviation() const;
+ std::string ToString() const;
+
+ private:
+ mutable mutex mu_;
+ Histogram histogram_ GUARDED_BY(mu_);
+};
+
+} // namespace histogram
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
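
ThreadSafeHistogram simply serializes every call behind one mutex, so it can be fed from several threads without external locking. A minimal usage sketch, assuming standard <thread> is available in this build (thread count and value range are illustrative):

#include <thread>
#include <vector>
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/logging.h"

int main() {
  tensorflow::histogram::ThreadSafeHistogram hist;

  // Four writers add values concurrently; each Add() takes the internal lock.
  std::vector<std::thread> writers;
  for (int t = 0; t < 4; ++t) {
    writers.emplace_back([&hist, t] {
      for (int i = 0; i < 1000; ++i) hist.Add(t * 1000 + i);
    });
  }
  for (std::thread& w : writers) w.join();

  LOG(INFO) << "median=" << hist.Median() << " avg=" << hist.Average()
            << " p99=" << hist.Percentile(99.0);
  return 0;
}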
diff --git a/tensorflow/core/lib/histogram/histogram_test.cc b/tensorflow/core/lib/histogram/histogram_test.cc
new file mode 100644
index 0000000000..ede44fe85b
--- /dev/null
+++ b/tensorflow/core/lib/histogram/histogram_test.cc
@@ -0,0 +1,112 @@
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include <float.h>
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace histogram {
+
+static void Validate(const Histogram& h) {
+ string s1 = h.ToString();
+ LOG(ERROR) << s1;
+
+ HistogramProto proto_with_zeroes;
+ h.EncodeToProto(&proto_with_zeroes, true);
+ Histogram h2;
+ EXPECT_TRUE(h2.DecodeFromProto(proto_with_zeroes));
+ string s2 = h2.ToString();
+ LOG(ERROR) << s2;
+
+ EXPECT_EQ(s1, s2);
+
+ HistogramProto proto_no_zeroes;
+ h.EncodeToProto(&proto_no_zeroes, false);
+ LOG(ERROR) << proto_no_zeroes.DebugString();
+ Histogram h3;
+ EXPECT_TRUE(h3.DecodeFromProto(proto_no_zeroes));
+ string s3 = h3.ToString();
+ LOG(ERROR) << s3;
+
+ EXPECT_EQ(s1, s3);
+}
+
+TEST(Histogram, Empty) {
+ Histogram h;
+ Validate(h);
+}
+
+TEST(Histogram, SingleValue) {
+ Histogram h;
+ h.Add(-3.0);
+ Validate(h);
+}
+
+TEST(Histogram, CustomBuckets) {
+ Histogram h({-10, -5, 0, 5, 10, 100, 1000, 10000, DBL_MAX});
+ h.Add(-3.0);
+ h.Add(4.99);
+ h.Add(5.0);
+ h.Add(1000.0);
+ Validate(h);
+}
+
+TEST(Histogram, Percentile) {
+ Histogram h({0, 10, 100, DBL_MAX});
+ h.Add(-2);
+ h.Add(-2);
+ h.Add(0);
+ double median = h.Percentile(50.0);
+ EXPECT_EQ(median, -0.5);
+}
+
+TEST(Histogram, Basic) {
+ Histogram h;
+ for (int i = 0; i < 100; i++) {
+ h.Add(i);
+ }
+ for (int i = 1000; i < 100000; i += 1000) {
+ h.Add(i);
+ }
+ Validate(h);
+}
+
+TEST(ThreadSafeHistogram, Basic) {
+ // Fill a normal histogram.
+ Histogram h;
+ for (int i = 0; i < 100; i++) {
+ h.Add(i);
+ }
+
+ // Fill a thread-safe histogram with the same values.
+ ThreadSafeHistogram tsh;
+ for (int i = 0; i < 100; i++) {
+ tsh.Add(i);
+ }
+
+ for (int i = 0; i < 2; ++i) {
+ bool preserve_zero_buckets = (i == 0);
+ HistogramProto h_proto;
+ h.EncodeToProto(&h_proto, preserve_zero_buckets);
+ HistogramProto tsh_proto;
+ tsh.EncodeToProto(&tsh_proto, preserve_zero_buckets);
+
+ // Let's decode from the proto of the other histogram type.
+ Histogram h2;
+ EXPECT_TRUE(h2.DecodeFromProto(tsh_proto));
+ ThreadSafeHistogram tsh2;
+ EXPECT_TRUE(tsh2.DecodeFromProto(h_proto));
+
+ // Now let's reencode and check they match.
+ EXPECT_EQ(h2.ToString(), tsh2.ToString());
+ }
+
+ EXPECT_EQ(h.Median(), tsh.Median());
+ EXPECT_EQ(h.Percentile(40.0), tsh.Percentile(40.0));
+ EXPECT_EQ(h.Average(), tsh.Average());
+ EXPECT_EQ(h.StandardDeviation(), tsh.StandardDeviation());
+ EXPECT_EQ(h.ToString(), tsh.ToString());
+}
+
+} // namespace histogram
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc
new file mode 100644
index 0000000000..1ddaa2eb78
--- /dev/null
+++ b/tensorflow/core/lib/io/block.cc
@@ -0,0 +1,236 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// Decodes the blocks generated by block_builder.cc.
+
+#include "tensorflow/core/lib/io/block.h"
+
+#include <vector>
+#include <algorithm>
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace table {
+
+inline uint32 Block::NumRestarts() const {
+ assert(size_ >= sizeof(uint32));
+ return core::DecodeFixed32(data_ + size_ - sizeof(uint32));
+}
+
+Block::Block(const BlockContents& contents)
+ : data_(contents.data.data()),
+ size_(contents.data.size()),
+ owned_(contents.heap_allocated) {
+ if (size_ < sizeof(uint32)) {
+ size_ = 0; // Error marker
+ } else {
+ size_t max_restarts_allowed = (size_ - sizeof(uint32)) / sizeof(uint32);
+ if (NumRestarts() > max_restarts_allowed) {
+ // The size is too small for NumRestarts()
+ size_ = 0;
+ } else {
+ restart_offset_ = size_ - (1 + NumRestarts()) * sizeof(uint32);
+ }
+ }
+}
+
+Block::~Block() {
+ if (owned_) {
+ delete[] data_;
+ }
+}
+
+// Helper routine: decode the next block entry starting at "p",
+// storing the number of shared key bytes, non_shared key bytes,
+// and the length of the value in "*shared", "*non_shared", and
+// "*value_length", respectively. Will not dereference past "limit".
+//
+// If any errors are detected, returns NULL. Otherwise, returns a
+// pointer to the key delta (just past the three decoded values).
+static inline const char* DecodeEntry(const char* p, const char* limit,
+ uint32* shared, uint32* non_shared,
+ uint32* value_length) {
+ if (limit - p < 3) return NULL;
+ *shared = reinterpret_cast<const unsigned char*>(p)[0];
+ *non_shared = reinterpret_cast<const unsigned char*>(p)[1];
+ *value_length = reinterpret_cast<const unsigned char*>(p)[2];
+ if ((*shared | *non_shared | *value_length) < 128) {
+ // Fast path: all three values are encoded in one byte each
+ p += 3;
+ } else {
+ if ((p = core::GetVarint32Ptr(p, limit, shared)) == NULL) return NULL;
+ if ((p = core::GetVarint32Ptr(p, limit, non_shared)) == NULL) return NULL;
+ if ((p = core::GetVarint32Ptr(p, limit, value_length)) == NULL) return NULL;
+ }
+
+ if (static_cast<uint32>(limit - p) < (*non_shared + *value_length)) {
+ return NULL;
+ }
+ return p;
+}
+
+class Block::Iter : public Iterator {
+ private:
+ const char* const data_; // underlying block contents
+ uint32 const restarts_; // Offset of restart array (list of fixed32)
+ uint32 const num_restarts_; // Number of uint32 entries in restart array
+
+ // current_ is offset in data_ of current entry. >= restarts_ if !Valid
+ uint32 current_;
+ uint32 restart_index_; // Index of restart block in which current_ falls
+ string key_;
+ StringPiece value_;
+ Status status_;
+
+ inline int Compare(const StringPiece& a, const StringPiece& b) const {
+ return a.compare(b);
+ }
+
+ // Return the offset in data_ just past the end of the current entry.
+ inline uint32 NextEntryOffset() const {
+ return (value_.data() + value_.size()) - data_;
+ }
+
+ uint32 GetRestartPoint(uint32 index) {
+ assert(index < num_restarts_);
+ return core::DecodeFixed32(data_ + restarts_ + index * sizeof(uint32));
+ }
+
+ void SeekToRestartPoint(uint32 index) {
+ key_.clear();
+ restart_index_ = index;
+ // current_ will be fixed by ParseNextKey();
+
+ // ParseNextKey() starts at the end of value_, so set value_ accordingly
+ uint32 offset = GetRestartPoint(index);
+ value_ = StringPiece(data_ + offset, 0);
+ }
+
+ public:
+ Iter(const char* data, uint32 restarts, uint32 num_restarts)
+ : data_(data),
+ restarts_(restarts),
+ num_restarts_(num_restarts),
+ current_(restarts_),
+ restart_index_(num_restarts_) {
+ assert(num_restarts_ > 0);
+ }
+
+ virtual bool Valid() const { return current_ < restarts_; }
+ virtual Status status() const { return status_; }
+ virtual StringPiece key() const {
+ assert(Valid());
+ return key_;
+ }
+ virtual StringPiece value() const {
+ assert(Valid());
+ return value_;
+ }
+
+ virtual void Next() {
+ assert(Valid());
+ ParseNextKey();
+ }
+
+ virtual void Seek(const StringPiece& target) {
+ // Binary search in restart array to find the last restart point
+ // with a key < target
+ uint32 left = 0;
+ uint32 right = num_restarts_ - 1;
+ while (left < right) {
+ uint32 mid = (left + right + 1) / 2;
+ uint32 region_offset = GetRestartPoint(mid);
+ uint32 shared, non_shared, value_length;
+ const char* key_ptr =
+ DecodeEntry(data_ + region_offset, data_ + restarts_, &shared,
+ &non_shared, &value_length);
+ if (key_ptr == NULL || (shared != 0)) {
+ CorruptionError();
+ return;
+ }
+ StringPiece mid_key(key_ptr, non_shared);
+ if (Compare(mid_key, target) < 0) {
+ // Key at "mid" is smaller than "target". Therefore all
+ // blocks before "mid" are uninteresting.
+ left = mid;
+ } else {
+ // Key at "mid" is >= "target". Therefore all blocks at or
+ // after "mid" are uninteresting.
+ right = mid - 1;
+ }
+ }
+
+ // Linear search (within restart block) for first key >= target
+ SeekToRestartPoint(left);
+ while (true) {
+ if (!ParseNextKey()) {
+ return;
+ }
+ if (Compare(key_, target) >= 0) {
+ return;
+ }
+ }
+ }
+
+ virtual void SeekToFirst() {
+ SeekToRestartPoint(0);
+ ParseNextKey();
+ }
+
+ private:
+ void CorruptionError() {
+ current_ = restarts_;
+ restart_index_ = num_restarts_;
+ status_ = errors::DataLoss("bad entry in block");
+ key_.clear();
+ value_.clear();
+ }
+
+ bool ParseNextKey() {
+ current_ = NextEntryOffset();
+ const char* p = data_ + current_;
+ const char* limit = data_ + restarts_; // Restarts come right after data
+ if (p >= limit) {
+ // No more entries to return. Mark as invalid.
+ current_ = restarts_;
+ restart_index_ = num_restarts_;
+ return false;
+ }
+
+ // Decode next entry
+ uint32 shared, non_shared, value_length;
+ p = DecodeEntry(p, limit, &shared, &non_shared, &value_length);
+ if (p == NULL || key_.size() < shared) {
+ CorruptionError();
+ return false;
+ } else {
+ key_.resize(shared);
+ key_.append(p, non_shared);
+ value_ = StringPiece(p + non_shared, value_length);
+ while (restart_index_ + 1 < num_restarts_ &&
+ GetRestartPoint(restart_index_ + 1) < current_) {
+ ++restart_index_;
+ }
+ return true;
+ }
+ }
+};
+
+Iterator* Block::NewIterator() {
+ if (size_ < sizeof(uint32)) {
+ return NewErrorIterator(errors::DataLoss("bad block contents"));
+ }
+ const uint32 num_restarts = NumRestarts();
+ if (num_restarts == 0) {
+ return NewEmptyIterator();
+ } else {
+ return new Iter(data_, restart_offset_, num_restarts);
+ }
+}
+
+} // namespace table
+} // namespace tensorflow
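
ParseNextKey() above rebuilds each key from two pieces: the first `shared` bytes are reused from the previously decoded key, and `non_shared` delta bytes follow in the block. The sketch below replays that reconstruction over a hand-written sequence of (shared, delta, value) triples; it illustrates the scheme, not the Block::Iter internals:

#include <cassert>
#include <string>
#include <vector>

struct Entry {
  size_t shared;       // bytes reused from the previous key
  std::string delta;   // the non-shared suffix stored in the block
  std::string value;
};

int main() {
  // Keys "apple", "applesauce", "banana" as they would be prefix-compressed
  // (the first entry after a restart point always has shared == 0).
  std::vector<Entry> entries = {
      {0, "apple", "v1"}, {5, "sauce", "v2"}, {0, "banana", "v3"}};

  std::string key;  // plays the role of Block::Iter::key_
  std::vector<std::string> decoded;
  for (const Entry& e : entries) {
    key.resize(e.shared);  // keep the shared prefix of the previous key
    key.append(e.delta);   // append the stored delta
    decoded.push_back(key);
  }

  assert(decoded[0] == "apple");
  assert(decoded[1] == "applesauce");
  assert(decoded[2] == "banana");
  return 0;
}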
diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h
new file mode 100644
index 0000000000..bf53245b8d
--- /dev/null
+++ b/tensorflow/core/lib/io/block.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_BLOCK_H_
+#define TENSORFLOW_LIB_IO_BLOCK_H_
+
+#include <stddef.h>
+#include <stdint.h>
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+struct BlockContents;
+
+class Block {
+ public:
+ // Initialize the block with the specified contents.
+ explicit Block(const BlockContents& contents);
+
+ ~Block();
+
+ size_t size() const { return size_; }
+ Iterator* NewIterator();
+
+ private:
+ uint32 NumRestarts() const;
+
+ const char* data_;
+ size_t size_;
+ uint32 restart_offset_; // Offset in data_ of restart array
+ bool owned_; // Block owns data_[]
+
+ // No copying allowed
+ Block(const Block&);
+ void operator=(const Block&);
+
+ class Iter;
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_BLOCK_H_
diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc
new file mode 100644
index 0000000000..d94048d744
--- /dev/null
+++ b/tensorflow/core/lib/io/block_builder.cc
@@ -0,0 +1,107 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// BlockBuilder generates blocks where keys are prefix-compressed:
+//
+// When we store a key, we drop the prefix shared with the previous
+// string. This helps reduce the space requirement significantly.
+// Furthermore, once every K keys, we do not apply the prefix
+// compression and store the entire key. We call this a "restart
+// point". The tail end of the block stores the offsets of all of the
+// restart points, and can be used to do a binary search when looking
+// for a particular key. Values are stored as-is (without compression)
+// immediately following the corresponding key.
+//
+// An entry for a particular key-value pair has the form:
+// shared_bytes: varint32
+// unshared_bytes: varint32
+// value_length: varint32
+// key_delta: char[unshared_bytes]
+// value: char[value_length]
+// shared_bytes == 0 for restart points.
+//
+// The trailer of the block has the form:
+// restarts: uint32[num_restarts]
+// num_restarts: uint32
+// restarts[i] contains the offset within the block of the ith restart point.
+
+#include "tensorflow/core/lib/io/block_builder.h"
+
+#include <algorithm>
+#include <assert.h>
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/core/coding.h"
+
+namespace tensorflow {
+namespace table {
+
+BlockBuilder::BlockBuilder(const Options* options)
+ : options_(options), restarts_(), counter_(0), finished_(false) {
+ assert(options->block_restart_interval >= 1);
+ restarts_.push_back(0); // First restart point is at offset 0
+}
+
+void BlockBuilder::Reset() {
+ buffer_.clear();
+ restarts_.clear();
+ restarts_.push_back(0); // First restart point is at offset 0
+ counter_ = 0;
+ finished_ = false;
+ last_key_.clear();
+}
+
+size_t BlockBuilder::CurrentSizeEstimate() const {
+ return (buffer_.size() + // Raw data buffer
+ restarts_.size() * sizeof(uint32) + // Restart array
+ sizeof(uint32)); // Restart array length
+}
+
+StringPiece BlockBuilder::Finish() {
+ // Append restart array
+ for (size_t i = 0; i < restarts_.size(); i++) {
+ core::PutFixed32(&buffer_, restarts_[i]);
+ }
+ core::PutFixed32(&buffer_, restarts_.size());
+ finished_ = true;
+ return StringPiece(buffer_);
+}
+
+void BlockBuilder::Add(const StringPiece& key, const StringPiece& value) {
+ StringPiece last_key_piece(last_key_);
+ assert(!finished_);
+ assert(counter_ <= options_->block_restart_interval);
+ assert(buffer_.empty() // No values yet?
+ || key.compare(last_key_piece) > 0);
+ size_t shared = 0;
+ if (counter_ < options_->block_restart_interval) {
+ // See how much sharing to do with previous string
+ const size_t min_length = std::min(last_key_piece.size(), key.size());
+ while ((shared < min_length) && (last_key_piece[shared] == key[shared])) {
+ shared++;
+ }
+ } else {
+ // Restart compression
+ restarts_.push_back(buffer_.size());
+ counter_ = 0;
+ }
+ const size_t non_shared = key.size() - shared;
+
+ // Add "<shared><non_shared><value_size>" to buffer_
+ core::PutVarint32(&buffer_, shared);
+ core::PutVarint32(&buffer_, non_shared);
+ core::PutVarint32(&buffer_, value.size());
+
+ // Add string delta to buffer_ followed by value
+ buffer_.append(key.data() + shared, non_shared);
+ buffer_.append(value.data(), value.size());
+
+ // Update state
+ last_key_.resize(shared);
+ last_key_.append(key.data() + shared, non_shared);
+ assert(StringPiece(last_key_) == key);
+ counter_++;
+}
+
+} // namespace table
+} // namespace tensorflow
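
The encode side is the mirror image of the decode sketch above: for each key, Add() computes how many leading bytes it shares with the previous key and emits only the remainder. A small worked example of those shared/non_shared values, assuming a restart interval large enough that no restart point is forced mid-sequence:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Keys added in sorted order, as BlockBuilder::Add requires.
  std::vector<std::string> keys = {"apple", "applesauce", "apricot"};
  std::string last_key;
  for (const std::string& key : keys) {
    // Length of the common prefix with the previous key.
    size_t shared = 0;
    const size_t min_len = std::min(last_key.size(), key.size());
    while (shared < min_len && last_key[shared] == key[shared]) shared++;
    const size_t non_shared = key.size() - shared;

    // The block stores <shared><non_shared><value_size> as varints,
    // followed by the key delta and then the value bytes.
    std::cout << key << ": shared=" << shared << " delta=\""
              << key.substr(shared) << "\" (" << non_shared << " bytes)\n";
    last_key = key;
  }
  return 0;
}
// Expected output:
//   apple: shared=0 delta="apple" (5 bytes)
//   applesauce: shared=5 delta="sauce" (5 bytes)
//   apricot: shared=2 delta="ricot" (5 bytes)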
diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h
new file mode 100644
index 0000000000..e07a647805
--- /dev/null
+++ b/tensorflow/core/lib/io/block_builder.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
+#define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
+
+#include <vector>
+
+#include <stdint.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace table {
+
+struct Options;
+
+class BlockBuilder {
+ public:
+ explicit BlockBuilder(const Options* options);
+
+ // Reset the contents as if the BlockBuilder was just constructed.
+ void Reset();
+
+ // REQUIRES: Finish() has not been called since the last call to Reset().
+ // REQUIRES: key is larger than any previously added key
+ void Add(const StringPiece& key, const StringPiece& value);
+
+ // Finish building the block and return a slice that refers to the
+ // block contents. The returned slice will remain valid for the
+ // lifetime of this builder or until Reset() is called.
+ StringPiece Finish();
+
+ // Returns an estimate of the current (uncompressed) size of the block
+ // we are building.
+ size_t CurrentSizeEstimate() const;
+
+ // Return true iff no entries have been added since the last Reset()
+ bool empty() const { return buffer_.empty(); }
+
+ private:
+ const Options* options_;
+ string buffer_; // Destination buffer
+ std::vector<uint32> restarts_; // Restart points
+ int counter_; // Number of entries emitted since restart
+ bool finished_; // Has Finish() been called?
+ string last_key_;
+
+ // No copying allowed
+ BlockBuilder(const BlockBuilder&);
+ void operator=(const BlockBuilder&);
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc
new file mode 100644
index 0000000000..259cfc13dc
--- /dev/null
+++ b/tensorflow/core/lib/io/format.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/format.h"
+
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace table {
+
+void BlockHandle::EncodeTo(string* dst) const {
+ // Sanity check that all fields have been set
+ assert(offset_ != ~static_cast<uint64>(0));
+ assert(size_ != ~static_cast<uint64>(0));
+ core::PutVarint64(dst, offset_);
+ core::PutVarint64(dst, size_);
+}
+
+Status BlockHandle::DecodeFrom(StringPiece* input) {
+ if (core::GetVarint64(input, &offset_) && core::GetVarint64(input, &size_)) {
+ return Status::OK();
+ } else {
+ return errors::DataLoss("bad block handle");
+ }
+}
+
+void Footer::EncodeTo(string* dst) const {
+#ifndef NDEBUG
+ const size_t original_size = dst->size();
+#endif
+ metaindex_handle_.EncodeTo(dst);
+ index_handle_.EncodeTo(dst);
+ dst->resize(2 * BlockHandle::kMaxEncodedLength); // Padding
+ core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber & 0xffffffffu));
+ core::PutFixed32(dst, static_cast<uint32>(kTableMagicNumber >> 32));
+ assert(dst->size() == original_size + kEncodedLength);
+}
+
+Status Footer::DecodeFrom(StringPiece* input) {
+ const char* magic_ptr = input->data() + kEncodedLength - 8;
+ const uint32 magic_lo = core::DecodeFixed32(magic_ptr);
+ const uint32 magic_hi = core::DecodeFixed32(magic_ptr + 4);
+ const uint64 magic =
+ ((static_cast<uint64>(magic_hi) << 32) | (static_cast<uint64>(magic_lo)));
+ if (magic != kTableMagicNumber) {
+ return errors::DataLoss("not an sstable (bad magic number)");
+ }
+
+ Status result = metaindex_handle_.DecodeFrom(input);
+ if (result.ok()) {
+ result = index_handle_.DecodeFrom(input);
+ }
+ if (result.ok()) {
+ // We skip over any leftover data (just padding for now) in "input"
+ const char* end = magic_ptr + 8;
+ *input = StringPiece(end, input->data() + input->size() - end);
+ }
+ return result;
+}
+
+Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle,
+ BlockContents* result) {
+ result->data = StringPiece();
+ result->cachable = false;
+ result->heap_allocated = false;
+
+ // Read the block contents as well as the type/crc footer.
+ // See table_builder.cc for the code that built this structure.
+ size_t n = static_cast<size_t>(handle.size());
+ char* buf = new char[n + kBlockTrailerSize];
+ StringPiece contents;
+ Status s =
+ file->Read(handle.offset(), n + kBlockTrailerSize, &contents, buf);
+ if (!s.ok()) {
+ delete[] buf;
+ return s;
+ }
+ if (contents.size() != n + kBlockTrailerSize) {
+ delete[] buf;
+ return errors::DataLoss("truncated block read");
+ }
+
+ // Check the crc of the type and the block contents
+ const char* data = contents.data(); // Pointer to where Read put the data
+ // This checksum verification is optional. We leave it on for now
+ const bool verify_checksum = true;
+ if (verify_checksum) {
+ const uint32 crc = crc32c::Unmask(core::DecodeFixed32(data + n + 1));
+ const uint32 actual = crc32c::Value(data, n + 1);
+ if (actual != crc) {
+ delete[] buf;
+ s = errors::DataLoss("block checksum mismatch");
+ return s;
+ }
+ }
+
+ switch (data[n]) {
+ case kNoCompression:
+ if (data != buf) {
+ // File implementation gave us pointer to some other data.
+ // Use it directly under the assumption that it will be live
+ // while the file is open.
+ delete[] buf;
+ result->data = StringPiece(data, n);
+ result->heap_allocated = false;
+ result->cachable = false; // Do not double-cache
+ } else {
+ result->data = StringPiece(buf, n);
+ result->heap_allocated = true;
+ result->cachable = true;
+ }
+
+ // Ok
+ break;
+ case kSnappyCompression: {
+ size_t ulength = 0;
+ if (!port::Snappy_GetUncompressedLength(data, n, &ulength)) {
+ delete[] buf;
+ return errors::DataLoss("corrupted compressed block contents");
+ }
+ char* ubuf = new char[ulength];
+ if (!port::Snappy_Uncompress(data, n, ubuf)) {
+ delete[] buf;
+ delete[] ubuf;
+ return errors::DataLoss("corrupted compressed block contents");
+ }
+ delete[] buf;
+ result->data = StringPiece(ubuf, ulength);
+ result->heap_allocated = true;
+ result->cachable = true;
+ break;
+ }
+ default:
+ delete[] buf;
+ return errors::DataLoss("bad block type");
+ }
+
+ return Status::OK();
+}
+
+} // namespace table
+} // namespace tensorflow
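
ReadBlock() expects every block to be followed by a 5-byte trailer: a compression-type byte plus a masked crc32c computed over the block contents and that type byte (kBlockTrailerSize in format.h). The writer lives in table_builder.cc, which is not part of this hunk, so the sketch below is only a layout guess derived from the verification code above; the kNoCompression value of 0 is an assumption:

#include <string>
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/hash/crc32c.h"

namespace {

// Append `block` plus a trailer that ReadBlock() above would accept:
//   [ block bytes | type byte | masked crc32c(block bytes + type byte) ]
void AppendBlockWithTrailer(const std::string& block, std::string* out) {
  const char type = 0x00;  // assumed to match kNoCompression for a raw block
  // The crc covers the block contents followed by the type byte.
  tensorflow::uint32 crc =
      tensorflow::crc32c::Value(block.data(), block.size());
  crc = tensorflow::crc32c::Extend(crc, &type, 1);
  out->append(block);
  out->push_back(type);
  tensorflow::core::PutFixed32(out, tensorflow::crc32c::Mask(crc));
}

}  // namespace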
diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h
new file mode 100644
index 0000000000..3121c41bb8
--- /dev/null
+++ b/tensorflow/core/lib/io/format.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_FORMAT_H_
+#define TENSORFLOW_LIB_IO_FORMAT_H_
+
+#include <string>
+#include <stdint.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+
+namespace tensorflow {
+class RandomAccessFile;
+namespace table {
+
+class Block;
+
+// BlockHandle is a pointer to the extent of a file that stores a data
+// block or a meta block.
+class BlockHandle {
+ public:
+ BlockHandle();
+
+ // The offset of the block in the file.
+ uint64 offset() const { return offset_; }
+ void set_offset(uint64 offset) { offset_ = offset; }
+
+ // The size of the stored block
+ uint64 size() const { return size_; }
+ void set_size(uint64 size) { size_ = size; }
+
+ void EncodeTo(string* dst) const;
+ Status DecodeFrom(StringPiece* input);
+
+ // Maximum encoding length of a BlockHandle
+ enum { kMaxEncodedLength = 10 + 10 };
+
+ private:
+ uint64 offset_;
+ uint64 size_;
+};
+
+// Footer encapsulates the fixed information stored at the tail
+// end of every table file.
+class Footer {
+ public:
+ Footer() {}
+
+ // The block handle for the metaindex block of the table
+ const BlockHandle& metaindex_handle() const { return metaindex_handle_; }
+ void set_metaindex_handle(const BlockHandle& h) { metaindex_handle_ = h; }
+
+ // The block handle for the index block of the table
+ const BlockHandle& index_handle() const { return index_handle_; }
+ void set_index_handle(const BlockHandle& h) { index_handle_ = h; }
+
+ void EncodeTo(string* dst) const;
+ Status DecodeFrom(StringPiece* input);
+
+ // Encoded length of a Footer. Note that the serialization of a
+ // Footer will always occupy exactly this many bytes. It consists
+ // of two block handles and a magic number.
+ enum { kEncodedLength = 2 * BlockHandle::kMaxEncodedLength + 8 };
+
+ private:
+ BlockHandle metaindex_handle_;
+ BlockHandle index_handle_;
+};
+
+// kTableMagicNumber was picked by running
+// echo http://code.google.com/p/leveldb/ | sha1sum
+// and taking the leading 64 bits.
+static const uint64 kTableMagicNumber = 0xdb4775248b80fb57ull;
+
+// 1-byte type + 32-bit crc
+static const size_t kBlockTrailerSize = 5;
+
+struct BlockContents {
+ StringPiece data; // Actual contents of data
+ bool cachable; // True iff data can be cached
+ bool heap_allocated; // True iff caller should delete[] data.data()
+};
+
+// Read the block identified by "handle" from "file". On failure
+// return non-OK. On success fill *result and return OK.
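+//
+// Example usage (an illustrative sketch; "file" and "handle" are assumed to
+// come from an already-decoded index entry):
+//
+//   BlockContents contents;
+//   Status s = ReadBlock(file, handle, &contents);
+//   if (s.ok()) {
+//     // ... use contents.data ...
+//     if (contents.heap_allocated) delete[] contents.data.data();
+//   }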
+extern Status ReadBlock(RandomAccessFile* file, const BlockHandle& handle,
+ BlockContents* result);
+
+// Implementation details follow.  Clients should ignore them.
+
+inline BlockHandle::BlockHandle()
+ : offset_(~static_cast<uint64>(0)), size_(~static_cast<uint64>(0)) {}
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_FORMAT_H_
diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc
new file mode 100644
index 0000000000..8fa245a546
--- /dev/null
+++ b/tensorflow/core/lib/io/inputbuffer.cc
@@ -0,0 +1,112 @@
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace io {
+
+InputBuffer::InputBuffer(RandomAccessFile* file, size_t buffer_bytes)
+ : file_(file),
+ file_pos_(0),
+ size_(buffer_bytes),
+ buf_(new char[size_]),
+ pos_(buf_),
+ limit_(buf_) {}
+
+InputBuffer::~InputBuffer() {
+ delete file_;
+ delete[] buf_;
+}
+
+Status InputBuffer::FillBuffer() {
+ StringPiece data;
+ Status s = file_->Read(file_pos_, size_, &data, buf_);
+ if (data.data() != buf_) {
+ memmove(buf_, data.data(), data.size());
+ }
+ pos_ = buf_;
+ limit_ = pos_ + data.size();
+ file_pos_ += data.size();
+ return s;
+}
+
+Status InputBuffer::ReadLine(string* result) {
+ result->clear();
+ int i;
+ Status s;
+ for (i = 0;; i++) {
+ if (pos_ == limit_) {
+ // Get more data into buffer
+ s = FillBuffer();
+ if (limit_ == buf_) {
+ break;
+ }
+ }
+ char c = *pos_++;
+ if (c == '\n') {
+ // We don't append the '\n' to *result
+ return Status::OK();
+ }
+ *result += c;
+ }
+ if (errors::IsOutOfRange(s) && !result->empty()) {
+ return Status::OK();
+ }
+ return s;
+}
+
+Status InputBuffer::ReadNBytes(int64 bytes_to_read, string* result) {
+ result->clear();
+ if (bytes_to_read < 0) {
+ return errors::InvalidArgument("Can't read a negative number of bytes: ",
+ bytes_to_read);
+ }
+ result->reserve(bytes_to_read);
+ Status s;
+ while (result->size() < static_cast<size_t>(bytes_to_read)) {
+ if (pos_ == limit_) {
+ // Get more data into buffer
+ s = FillBuffer();
+ if (limit_ == buf_) {
+ break;
+ }
+ }
+ const int64 bytes_to_copy =
+ std::min<int64>(limit_ - pos_, bytes_to_read - result->size());
+ result->insert(result->size(), pos_, bytes_to_copy);
+ pos_ += bytes_to_copy;
+ }
+ if (errors::IsOutOfRange(s) &&
+ (result->size() == static_cast<size_t>(bytes_to_read))) {
+ return Status::OK();
+ }
+ return s;
+}
+
+Status InputBuffer::SkipNBytes(int64 bytes_to_skip) {
+ if (bytes_to_skip < 0) {
+ return errors::InvalidArgument("Can only skip forward, not ",
+ bytes_to_skip);
+ }
+ int64 bytes_skipped = 0;
+ Status s;
+ while (bytes_skipped < bytes_to_skip) {
+ if (pos_ == limit_) {
+ // Get more data into buffer
+ s = FillBuffer();
+ if (limit_ == buf_) {
+ break;
+ }
+ }
+ const int64 bytes_to_advance =
+ std::min<int64>(limit_ - pos_, bytes_to_skip - bytes_skipped);
+ bytes_skipped += bytes_to_advance;
+ pos_ += bytes_to_advance;
+ }
+ if (errors::IsOutOfRange(s) && bytes_skipped == bytes_to_skip) {
+ return Status::OK();
+ }
+ return s;
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h
new file mode 100644
index 0000000000..6879f30567
--- /dev/null
+++ b/tensorflow/core/lib/io/inputbuffer.h
@@ -0,0 +1,62 @@
+#ifndef TENSORFLOW_LIB_IO_INPUTBUFFER_H_
+#define TENSORFLOW_LIB_IO_INPUTBUFFER_H_
+
+#include <string>
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace io {
+
+// An InputBuffer provides a buffer on top of a RandomAccessFile.
+// A given instance of an InputBuffer is NOT safe for concurrent use
+// by multiple threads.
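+//
+// Example usage (an illustrative sketch; error handling is elided and the
+// file name is hypothetical):
+//
+//   Env* env = Env::Default();
+//   RandomAccessFile* file;
+//   TF_CHECK_OK(env->NewRandomAccessFile("/tmp/data.txt", &file));
+//   io::InputBuffer in(file, 64 << 10);  // takes ownership of "file"
+//   string line;
+//   while (in.ReadLine(&line).ok()) {
+//     // ... process "line" (the trailing '\n' is not included) ...
+//   }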
+class InputBuffer {
+ public:
+ // Create an InputBuffer for "file" with a buffer size of
+ // "buffer_bytes" bytes. Takes ownership of "file" and will
+ // delete it when the InputBuffer is destroyed.
+ InputBuffer(RandomAccessFile* file, size_t buffer_bytes);
+ ~InputBuffer();
+
+ // Read one text line of data into "*result" until end-of-file or a
+ // \n is read. (The \n is not included in the result.) Overwrites
+ // any existing data in *result.
+ //
+ // If successful, returns OK. If we are already at the end of the
+ // file, we return an OUT_OF_RANGE error. Otherwise, we return
+ // some other non-OK status.
+ Status ReadLine(string* result);
+
+ // Reads bytes_to_read bytes into *result, overwriting *result.
+ //
+  // If successful, returns OK. If there are not enough bytes to
+ // read before the end of the file, we return an OUT_OF_RANGE error.
+ // Otherwise, we return some other non-OK status.
+ Status ReadNBytes(int64 bytes_to_read, string* result);
+
+ // Like ReadNBytes() without returning the bytes read.
+ Status SkipNBytes(int64 bytes_to_skip);
+
+  // Returns the offset in the file of the next byte to be returned
+  // by ReadLine() or ReadNBytes().
+ int64 Tell() const { return file_pos_ - (limit_ - pos_); }
+
+ private:
+ Status FillBuffer();
+
+ RandomAccessFile* file_; // Owned
+ int64 file_pos_; // Next position to read from in "file_"
+ size_t size_; // Size of "buf_"
+ char* buf_; // The buffer itself
+ // [pos_,limit_) hold the "limit_ - pos_" bytes just before "file_pos_"
+  char* pos_;      // Current position in "buf_"
+  char* limit_;    // Just past end of valid data in "buf_"
+
+ TF_DISALLOW_COPY_AND_ASSIGN(InputBuffer);
+};
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_INPUTBUFFER_H_
diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc
new file mode 100644
index 0000000000..34094f018c
--- /dev/null
+++ b/tensorflow/core/lib/io/inputbuffer_test.cc
@@ -0,0 +1,174 @@
+#include "tensorflow/core/lib/io/inputbuffer.h"
+
+#include "tensorflow/core/public/env.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+static std::vector<int> BufferSizes() {
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 65536};
+}
+
+TEST(InputBuffer, ReadLine_Empty) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadLine1) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\nline two\nline three\n");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadLine_NoTrailingNewLine) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\nline two\nline three");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadLine_EmptyLines) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\n\n\nline two\nline three");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
+TEST(InputBuffer, ReadNBytes) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "0123456789");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string read;
+ io::InputBuffer in(file, buf_size);
+ EXPECT_EQ(0, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(3, &read));
+ EXPECT_EQ(read, "012");
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(0, &read));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(4, &read));
+ EXPECT_EQ(read, "3456");
+ EXPECT_EQ(7, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(0, &read));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(7, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read)));
+ EXPECT_EQ(read, "789");
+ EXPECT_EQ(10, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read)));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(10, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(0, &read));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(10, in.Tell());
+ }
+}
+
+TEST(InputBuffer, SkipNBytes) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "0123456789");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string read;
+ io::InputBuffer in(file, buf_size);
+ EXPECT_EQ(0, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(3));
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(0));
+ EXPECT_EQ(3, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(2, &read));
+ EXPECT_EQ(read, "34");
+ EXPECT_EQ(5, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(0));
+ EXPECT_EQ(5, in.Tell());
+ TF_CHECK_OK(in.SkipNBytes(2));
+ EXPECT_EQ(7, in.Tell());
+ TF_CHECK_OK(in.ReadNBytes(1, &read));
+ EXPECT_EQ(read, "7");
+ EXPECT_EQ(8, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5)));
+ EXPECT_EQ(10, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.SkipNBytes(5)));
+ EXPECT_EQ(10, in.Tell());
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadNBytes(5, &read)));
+ EXPECT_EQ(read, "");
+ EXPECT_EQ(10, in.Tell());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/iterator.cc b/tensorflow/core/lib/io/iterator.cc
new file mode 100644
index 0000000000..878e93a911
--- /dev/null
+++ b/tensorflow/core/lib/io/iterator.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+Iterator::Iterator() {
+ cleanup_.function = NULL;
+ cleanup_.next = NULL;
+}
+
+Iterator::~Iterator() {
+ if (cleanup_.function != NULL) {
+ (*cleanup_.function)(cleanup_.arg1, cleanup_.arg2);
+ for (Cleanup* c = cleanup_.next; c != NULL;) {
+ (*c->function)(c->arg1, c->arg2);
+ Cleanup* next = c->next;
+ delete c;
+ c = next;
+ }
+ }
+}
+
+void Iterator::RegisterCleanup(CleanupFunction func, void* arg1, void* arg2) {
+ assert(func != NULL);
+ Cleanup* c;
+ if (cleanup_.function == NULL) {
+ c = &cleanup_;
+ } else {
+ c = new Cleanup;
+ c->next = cleanup_.next;
+ cleanup_.next = c;
+ }
+ c->function = func;
+ c->arg1 = arg1;
+ c->arg2 = arg2;
+}
+
+namespace {
+class EmptyIterator : public Iterator {
+ public:
+ EmptyIterator(const Status& s) : status_(s) {}
+ virtual bool Valid() const { return false; }
+ virtual void Seek(const StringPiece& target) {}
+ virtual void SeekToFirst() {}
+ virtual void Next() { assert(false); }
+ StringPiece key() const {
+ assert(false);
+ return StringPiece();
+ }
+ StringPiece value() const {
+ assert(false);
+ return StringPiece();
+ }
+ virtual Status status() const { return status_; }
+
+ private:
+ Status status_;
+};
+} // namespace
+
+Iterator* NewEmptyIterator() { return new EmptyIterator(Status::OK()); }
+
+Iterator* NewErrorIterator(const Status& status) {
+ return new EmptyIterator(status);
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h
new file mode 100644
index 0000000000..603a2f95fe
--- /dev/null
+++ b/tensorflow/core/lib/io/iterator.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// An iterator yields a sequence of key/value pairs from a source.
+// The following class defines the interface. Multiple implementations
+// are provided by this library. In particular, iterators are provided
+// to access the contents of a Table.
+//
+// Multiple threads can invoke const methods on an Iterator without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same Iterator must use
+// external synchronization.
+
+#ifndef TENSORFLOW_LIB_IO_ITERATOR_H_
+#define TENSORFLOW_LIB_IO_ITERATOR_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+namespace table {
+
+class Iterator {
+ public:
+ Iterator();
+ virtual ~Iterator();
+
+ // An iterator is either positioned at a key/value pair, or
+ // not valid. This method returns true iff the iterator is valid.
+ virtual bool Valid() const = 0;
+
+ // Position at the first key in the source. The iterator is Valid()
+ // after this call iff the source is not empty.
+ virtual void SeekToFirst() = 0;
+
+ // Position at the first key in the source that is at or past target.
+ // The iterator is Valid() after this call iff the source contains
+ // an entry that comes at or past target.
+ virtual void Seek(const StringPiece& target) = 0;
+
+ // Moves to the next entry in the source. After this call, Valid() is
+ // true iff the iterator was not positioned at the last entry in the source.
+ // REQUIRES: Valid()
+ virtual void Next() = 0;
+
+ // Return the key for the current entry. The underlying storage for
+ // the returned slice is valid only until the next modification of
+ // the iterator.
+ // REQUIRES: Valid()
+ virtual StringPiece key() const = 0;
+
+ // Return the value for the current entry. The underlying storage for
+ // the returned slice is valid only until the next modification of
+ // the iterator.
+ // REQUIRES: Valid()
+ virtual StringPiece value() const = 0;
+
+ // If an error has occurred, return it. Else return an ok status.
+ virtual Status status() const = 0;
+
+ // Clients are allowed to register function/arg1/arg2 triples that
+ // will be invoked when this iterator is destroyed.
+ //
+ // Note that unlike all of the preceding methods, this method is
+ // not abstract and therefore clients should not override it.
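+  //
+  // For example, table.cc in this library registers a cleanup that deletes
+  // the heap-allocated Block backing a block iterator once that iterator is
+  // destroyed:
+  //
+  //   iter->RegisterCleanup(&DeleteBlock, block, NULL);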
+ typedef void (*CleanupFunction)(void* arg1, void* arg2);
+ void RegisterCleanup(CleanupFunction function, void* arg1, void* arg2);
+
+ private:
+ struct Cleanup {
+ CleanupFunction function;
+ void* arg1;
+ void* arg2;
+ Cleanup* next;
+ };
+ Cleanup cleanup_;
+
+ // No copying allowed
+ Iterator(const Iterator&);
+ void operator=(const Iterator&);
+};
+
+// Return an empty iterator (yields nothing).
+extern Iterator* NewEmptyIterator();
+
+// Return an empty iterator with the specified status.
+extern Iterator* NewErrorIterator(const Status& status);
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_ITERATOR_H_
diff --git a/tensorflow/core/lib/io/match.cc b/tensorflow/core/lib/io/match.cc
new file mode 100644
index 0000000000..1563642d0b
--- /dev/null
+++ b/tensorflow/core/lib/io/match.cc
@@ -0,0 +1,31 @@
+#include "tensorflow/core/lib/io/match.h"
+#include <fnmatch.h>
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace io {
+
+Status GetMatchingFiles(Env* env, const string& pattern,
+ std::vector<string>* results) {
+ results->clear();
+ std::vector<string> all_files;
+ string dir = Dirname(pattern).ToString();
+ if (dir.empty()) dir = ".";
+ string basename_pattern = Basename(pattern).ToString();
+ Status s = env->GetChildren(dir, &all_files);
+ if (!s.ok()) {
+ return s;
+ }
+ for (const auto& f : all_files) {
+ int flags = 0;
+ if (fnmatch(basename_pattern.c_str(), Basename(f).ToString().c_str(),
+ flags) == 0) {
+ results->push_back(JoinPath(dir, f));
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/match.h b/tensorflow/core/lib/io/match.h
new file mode 100644
index 0000000000..fd194178e7
--- /dev/null
+++ b/tensorflow/core/lib/io/match.h
@@ -0,0 +1,24 @@
+#ifndef TENSORFLOW_LIB_IO_MATCH_H_
+#define TENSORFLOW_LIB_IO_MATCH_H_
+
+#include <vector>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+class Env;
+namespace io {
+
+// Given a pattern, return the set of files that match the pattern.
+// Note that this routine only supports wildcard characters in the
+// basename portion of the pattern, not in the directory portion. If
+// successful, return Status::OK and store the matching files in
+// "*results". Otherwise, return a non-OK status.
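+//
+// Example usage (an illustrative sketch; the pattern shown is hypothetical):
+//
+//   std::vector<string> files;
+//   Status s = GetMatchingFiles(Env::Default(), "/tmp/events-*", &files);
+//   // On success, "files" contains every file in /tmp whose basename
+//   // matches "events-*".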
+Status GetMatchingFiles(Env* env, const string& pattern,
+ std::vector<string>* results);
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_MATCH_H_
diff --git a/tensorflow/core/lib/io/match_test.cc b/tensorflow/core/lib/io/match_test.cc
new file mode 100644
index 0000000000..aaa56e4e7e
--- /dev/null
+++ b/tensorflow/core/lib/io/match_test.cc
@@ -0,0 +1,51 @@
+#include <algorithm>
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/match.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace io {
+
+static string Match(Env* env, const string& suffix_pattern) {
+ std::vector<string> results;
+ Status s = GetMatchingFiles(env, JoinPath(testing::TmpDir(), suffix_pattern),
+ &results);
+ if (!s.ok()) {
+ return s.ToString();
+ } else {
+ string r;
+ std::sort(results.begin(), results.end());
+ for (size_t i = 0; i < results.size(); i++) {
+ strings::StrAppend(&r, (i > 0) ? "," : "", Basename(results[i]));
+ }
+ return r;
+ }
+}
+
+TEST(GetMatchingFiles, Simple) {
+ Env* env = Env::Default();
+ EXPECT_EQ(Match(env, "thereisnosuchfile"), "");
+ EXPECT_EQ(Match(env, "thereisnosuchfile*"), "");
+
+ // Populate a few files
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-00"), ""));
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-0a"), ""));
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-01"), ""));
+ EXPECT_OK(WriteStringToFile(Env::Default(),
+ JoinPath(testing::TmpDir(), "match-aaa"), ""));
+
+ EXPECT_EQ(Match(env, "match-*"), "match-00,match-01,match-0a,match-aaa");
+ EXPECT_EQ(Match(env, "match-0[0-9]"), "match-00,match-01");
+ EXPECT_EQ(Match(env, "match-?[0-9]"), "match-00,match-01");
+ EXPECT_EQ(Match(env, "match-?a*"), "match-0a,match-aaa");
+ EXPECT_EQ(Match(env, "match-??"), "match-00,match-01,match-0a");
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc
new file mode 100644
index 0000000000..1359ded0f0
--- /dev/null
+++ b/tensorflow/core/lib/io/path.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace io {
+
+string JoinPath(StringPiece part1, StringPiece part2) {
+ string result;
+
+ StringPiece paths[2] = {part1, part2};
+ for (StringPiece path : paths) {
+ if (path.empty()) continue;
+
+ if (result.empty()) {
+ result = path.ToString();
+ continue;
+ }
+
+ if (result[result.size() - 1] == '/') {
+ if (IsAbsolutePath(path)) {
+ strings::StrAppend(&result, path.substr(1));
+ } else {
+ strings::StrAppend(&result, path);
+ }
+ } else {
+ if (IsAbsolutePath(path)) {
+ strings::StrAppend(&result, path);
+ } else {
+ strings::StrAppend(&result, "/", path);
+ }
+ }
+ }
+
+ return result;
+}
+
+namespace internal {
+
+// Return the parts of the path, split on the final "/". If there is no
+// "/" in the path, the first part of the output is empty and the second
+// is the input. If the only "/" in the path is the first character, it is
+// the first part of the output.
+std::pair<StringPiece, StringPiece> SplitPath(StringPiece path) {
+ auto pos = path.rfind('/');
+
+ // Handle the case with no '/' in 'path'.
+ if (pos == StringPiece::npos)
+ return std::make_pair(StringPiece(path.data(), 0), path);
+
+ // Handle the case with a single leading '/' in 'path'.
+ if (pos == 0)
+ return std::make_pair(StringPiece(path.data(), 1),
+ StringPiece(path.data() + 1, path.size() - 1));
+
+ return std::make_pair(
+ StringPiece(path.data(), pos),
+ StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
+}
+
+// Return the parts of the basename of path, split on the final ".".
+// If there is no "." in the basename or "." is the final character in the
+// basename, the second value will be empty.
+std::pair<StringPiece, StringPiece> SplitBasename(StringPiece path) {
+ path = Basename(path);
+
+ auto pos = path.rfind('.');
+ if (pos == StringPiece::npos)
+ return std::make_pair(path, StringPiece(path.data() + path.size(), 0));
+ return std::make_pair(
+ StringPiece(path.data(), pos),
+ StringPiece(path.data() + pos + 1, path.size() - (pos + 1)));
+}
+} // namespace internal
+
+bool IsAbsolutePath(StringPiece path) {
+ return !path.empty() && path[0] == '/';
+}
+
+StringPiece Dirname(StringPiece path) {
+ return internal::SplitPath(path).first;
+}
+
+StringPiece Basename(StringPiece path) {
+ return internal::SplitPath(path).second;
+}
+
+StringPiece Extension(StringPiece path) {
+ return internal::SplitBasename(path).second;
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
new file mode 100644
index 0000000000..01483f1702
--- /dev/null
+++ b/tensorflow/core/lib/io/path.h
@@ -0,0 +1,47 @@
+#ifndef TENSORFLOW_LIB_IO_PATH_H_
+#define TENSORFLOW_LIB_IO_PATH_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+class StringPiece;
+namespace io {
+
+// Utility routines for processing filenames
+
+// Join multiple paths together, without introducing unnecessary path
+// separators.
+// For example:
+//
+// Arguments | JoinPath
+// ---------------------------+----------
+// '/foo', 'bar' | /foo/bar
+// '/foo/', 'bar' | /foo/bar
+// '/foo', '/bar' | /foo/bar
+//
+// Usage:
+// string path = io::JoinPath("/mydir", filename);
+// string path = io::JoinPath(FLAGS_test_srcdir, filename);
+string JoinPath(StringPiece part1, StringPiece part2);
+
+// Return true if path is absolute.
+bool IsAbsolutePath(StringPiece path);
+
+// Returns the part of the path before the final "/". If there is a single
+// leading "/" in the path, the result will be the leading "/". If there is
+// no "/" in the path, the result is the empty prefix of the input.
+StringPiece Dirname(StringPiece path);
+
+// Returns the part of the path after the final "/". If there is no
+// "/" in the path, the result is the same as the input.
+StringPiece Basename(StringPiece path);
+
+// Returns the part of the basename of path after the final ".". If
+// there is no "." in the basename, the result is empty.
+StringPiece Extension(StringPiece path);
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_PATH_H_
diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc
new file mode 100644
index 0000000000..b670e44f1f
--- /dev/null
+++ b/tensorflow/core/lib/io/path_test.cc
@@ -0,0 +1,65 @@
+#include "tensorflow/core/lib/io/path.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace io {
+
+TEST(PathTest, JoinPath) {
+ EXPECT_EQ("/foo/bar", JoinPath("/foo", "bar"));
+ EXPECT_EQ("foo/bar", JoinPath("foo", "bar"));
+ EXPECT_EQ("foo/bar", JoinPath("foo", "/bar"));
+ EXPECT_EQ("/foo/bar", JoinPath("/foo", "/bar"));
+
+ EXPECT_EQ("/bar", JoinPath("", "/bar"));
+ EXPECT_EQ("bar", JoinPath("", "bar"));
+ EXPECT_EQ("/foo", JoinPath("/foo", ""));
+
+ EXPECT_EQ("/foo/bar/baz/blah/blink/biz",
+ JoinPath("/foo/bar/baz/", "/blah/blink/biz"));
+}
+
+TEST(PathTest, IsAbsolutePath) {
+ EXPECT_FALSE(IsAbsolutePath(""));
+ EXPECT_FALSE(IsAbsolutePath("../foo"));
+ EXPECT_FALSE(IsAbsolutePath("foo"));
+ EXPECT_FALSE(IsAbsolutePath("./foo"));
+ EXPECT_FALSE(IsAbsolutePath("foo/bar/baz/"));
+ EXPECT_TRUE(IsAbsolutePath("/foo"));
+ EXPECT_TRUE(IsAbsolutePath("/foo/bar/../baz"));
+}
+
+TEST(PathTest, Dirname) {
+ EXPECT_EQ("/hello", Dirname("/hello/"));
+ EXPECT_EQ("/", Dirname("/hello"));
+ EXPECT_EQ("hello", Dirname("hello/world"));
+ EXPECT_EQ("hello", Dirname("hello/"));
+ EXPECT_EQ("", Dirname("world"));
+ EXPECT_EQ("/", Dirname("/"));
+ EXPECT_EQ("", Dirname(""));
+}
+
+TEST(PathTest, Basename) {
+ EXPECT_EQ("", Basename("/hello/"));
+ EXPECT_EQ("hello", Basename("/hello"));
+ EXPECT_EQ("world", Basename("hello/world"));
+ EXPECT_EQ("", Basename("hello/"));
+ EXPECT_EQ("world", Basename("world"));
+ EXPECT_EQ("", Basename("/"));
+ EXPECT_EQ("", Basename(""));
+}
+
+TEST(PathTest, Extension) {
+ EXPECT_EQ("gif", Extension("foo.gif"));
+ EXPECT_EQ("", Extension("foo."));
+ EXPECT_EQ("", Extension(""));
+ EXPECT_EQ("", Extension("/"));
+ EXPECT_EQ("", Extension("foo"));
+ EXPECT_EQ("", Extension("foo/"));
+ EXPECT_EQ("gif", Extension("/a/path/to/foo.gif"));
+ EXPECT_EQ("html", Extension("/a/path.bar/to/foo.html"));
+ EXPECT_EQ("", Extension("/a/path.bar/to/foo"));
+ EXPECT_EQ("baz", Extension("/a/path.bar/to/foo.bar.baz"));
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
new file mode 100644
index 0000000000..2f0fabff63
--- /dev/null
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -0,0 +1,80 @@
+#include "tensorflow/core/lib/io/record_reader.h"
+
+#include <limits.h>
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace io {
+
+RecordReader::RecordReader(RandomAccessFile* file) : src_(file) {}
+
+RecordReader::~RecordReader() {}
+
+// Read n+4 bytes from file, verify that checksum of first n bytes is
+// stored in the last 4 bytes and store the first n bytes in *result.
+// May use *storage as backing store.
+static Status ReadChecksummed(RandomAccessFile* file, uint64 offset,
+ size_t n, StringPiece* result,
+ string* storage) {
+ if (n >= SIZE_MAX - sizeof(uint32)) {
+ return errors::DataLoss("record size too large");
+ }
+
+ const size_t expected = n + sizeof(uint32);
+ storage->resize(expected);
+ StringPiece data;
+ Status s = file->Read(offset, expected, &data, &(*storage)[0]);
+ if (!s.ok()) {
+ return s;
+ }
+ if (data.size() != expected) {
+ if (data.size() == 0) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
+ uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(data.data(), n);
+ return Status::OK();
+}
+
+Status RecordReader::ReadRecord(uint64* offset, string* record) {
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
+ // Read length
+ StringPiece lbuf;
+ Status s = ReadChecksummed(src_, *offset, sizeof(uint64), &lbuf, record);
+ if (!s.ok()) {
+ return s;
+ }
+ const uint64 length = core::DecodeFixed64(lbuf.data());
+
+ // Read data
+ StringPiece data;
+ s = ReadChecksummed(src_, *offset + kHeaderSize, length, &data, record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ s = errors::DataLoss("truncated record at ", *offset);
+ }
+ return s;
+ }
+ if (record->data() != data.data()) {
+ // RandomAccessFile placed the data in some other location.
+ memmove(&(*record)[0], data.data(), data.size());
+ }
+
+ record->resize(data.size());
+ *offset += kHeaderSize + length + kFooterSize;
+ return Status::OK();
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
new file mode 100644
index 0000000000..a8c1b0dd5d
--- /dev/null
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_
+#define TENSORFLOW_LIB_IO_RECORD_READER_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class RandomAccessFile;
+
+namespace io {
+
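+// Example usage (an illustrative sketch; assumes "file" was opened with
+// Env::NewRandomAccessFile and outlives the reader):
+//
+//   io::RecordReader reader(file);
+//   uint64 offset = 0;
+//   string record;
+//   while (reader.ReadRecord(&offset, &record).ok()) {
+//     // ... process "record" ...
+//   }
+//   // The loop stops at end of file (OUT_OF_RANGE) or on a corrupt record.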
+class RecordReader {
+ public:
+ // Create a reader that will return log records from "*file".
+ // "*file" must remain live while this Reader is in use.
+ explicit RecordReader(RandomAccessFile* file);
+
+ ~RecordReader();
+
+ // Read the record at "*offset" into *record and update *offset to
+ // point to the offset of the next record. Returns OK on success,
+ // OUT_OF_RANGE for end of file, or something else for an error.
+ Status ReadRecord(uint64* offset, string* record);
+
+ private:
+ RandomAccessFile* src_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
+};
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
new file mode 100644
index 0000000000..3d7f1509ab
--- /dev/null
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -0,0 +1,42 @@
+#include "tensorflow/core/lib/io/record_writer.h"
+
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+
+namespace tensorflow {
+namespace io {
+
+RecordWriter::RecordWriter(WritableFile* dest) : dest_(dest) {}
+
+RecordWriter::~RecordWriter() {}
+
+static uint32 MaskedCrc(const char* data, size_t n) {
+ return crc32c::Mask(crc32c::Value(data, n));
+}
+
+Status RecordWriter::WriteRecord(StringPiece data) {
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
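+  // A record with a payload of length L therefore occupies L + 16 bytes:
+  //   bytes [0, 8)        little-endian encoding of L
+  //   bytes [8, 12)       masked crc32c of the 8 length bytes
+  //   bytes [12, 12+L)    the payload itself
+  //   bytes [12+L, 16+L)  masked crc32c of the payload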
+ char header[sizeof(uint64) + sizeof(uint32)];
+ core::EncodeFixed64(header + 0, data.size());
+ core::EncodeFixed32(header + sizeof(uint64),
+ MaskedCrc(header, sizeof(uint64)));
+ Status s = dest_->Append(StringPiece(header, sizeof(header)));
+ if (!s.ok()) {
+ return s;
+ }
+ s = dest_->Append(data);
+ if (!s.ok()) {
+ return s;
+ }
+ char footer[sizeof(uint32)];
+ core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
+ return dest_->Append(StringPiece(footer, sizeof(footer)));
+}
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
new file mode 100644
index 0000000000..c7af00e5ae
--- /dev/null
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -0,0 +1,34 @@
+#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+class WritableFile;
+
+namespace io {
+
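+// Example usage (an illustrative sketch; assumes "file" is an empty
+// WritableFile created with Env::NewWritableFile and closed by the caller):
+//
+//   io::RecordWriter writer(file);
+//   TF_CHECK_OK(writer.WriteRecord("payload bytes"));
+//   TF_CHECK_OK(file->Close());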
+class RecordWriter {
+ public:
+ // Create a writer that will append data to "*dest".
+ // "*dest" must be initially empty.
+ // "*dest" must remain live while this Writer is in use.
+ explicit RecordWriter(WritableFile* dest);
+
+ ~RecordWriter();
+
+ Status WriteRecord(StringPiece slice);
+
+ private:
+ WritableFile* const dest_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
+};
+
+} // namespace io
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
new file mode 100644
index 0000000000..3e9c816443
--- /dev/null
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -0,0 +1,245 @@
+#include "tensorflow/core/lib/io/record_reader.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace io {
+
+// Construct a string of the specified length made out of the supplied
+// partial string.
+static string BigString(const string& partial_string, size_t n) {
+ string result;
+ while (result.size() < n) {
+ result.append(partial_string);
+ }
+ result.resize(n);
+ return result;
+}
+
+// Construct a string from a number
+static string NumberString(int n) {
+ char buf[50];
+ snprintf(buf, sizeof(buf), "%d.", n);
+ return string(buf);
+}
+
+// Return a skewed potentially long string
+static string RandomSkewedString(int i, random::SimplePhilox* rnd) {
+ return BigString(NumberString(i), rnd->Skewed(17));
+}
+
+class RecordioTest : public testing::Test {
+ private:
+ class StringDest : public WritableFile {
+ public:
+ string contents_;
+
+ Status Close() override { return Status::OK(); }
+ Status Flush() override { return Status::OK(); }
+ Status Sync() override { return Status::OK(); }
+ Status Append(const StringPiece& slice) override {
+ contents_.append(slice.data(), slice.size());
+ return Status::OK();
+ }
+ };
+
+ class StringSource : public RandomAccessFile {
+ public:
+ StringPiece contents_;
+ mutable bool force_error_;
+ mutable bool returned_partial_;
+ StringSource() : force_error_(false), returned_partial_(false) {}
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override {
+ EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error";
+
+ if (force_error_) {
+ force_error_ = false;
+ returned_partial_ = true;
+ return errors::DataLoss("read error");
+ }
+
+ if (offset >= contents_.size()) {
+ return errors::OutOfRange("end of file");
+ }
+
+ if (contents_.size() < offset + n) {
+ n = contents_.size() - offset;
+ returned_partial_ = true;
+ }
+ *result = StringPiece(contents_.data() + offset, n);
+ return Status::OK();
+ }
+ };
+
+ StringDest dest_;
+ StringSource source_;
+ bool reading_;
+ uint64 readpos_;
+ RecordWriter* writer_;
+ RecordReader* reader_;
+
+ public:
+ RecordioTest()
+ : reading_(false),
+ readpos_(0),
+ writer_(new RecordWriter(&dest_)),
+ reader_(new RecordReader(&source_)) {}
+
+ ~RecordioTest() override {
+ delete writer_;
+ delete reader_;
+ }
+
+ void Write(const string& msg) {
+ ASSERT_TRUE(!reading_) << "Write() after starting to read";
+ ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
+ }
+
+ size_t WrittenBytes() const { return dest_.contents_.size(); }
+
+ string Read() {
+ if (!reading_) {
+ reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
+ }
+ string record;
+ Status s = reader_->ReadRecord(&readpos_, &record);
+ if (s.ok()) {
+ return record;
+ } else if (errors::IsOutOfRange(s)) {
+ return "EOF";
+ } else {
+ return s.ToString();
+ }
+ }
+
+ void IncrementByte(int offset, int delta) {
+ dest_.contents_[offset] += delta;
+ }
+
+ void SetByte(int offset, char new_byte) {
+ dest_.contents_[offset] = new_byte;
+ }
+
+ void ShrinkSize(int bytes) {
+ dest_.contents_.resize(dest_.contents_.size() - bytes);
+ }
+
+ void FixChecksum(int header_offset, int len) {
+ // Compute crc of type/len/data
+ uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len);
+ crc = crc32c::Mask(crc);
+ core::EncodeFixed32(&dest_.contents_[header_offset], crc);
+ }
+
+ void ForceError() { source_.force_error_ = true; }
+
+ void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
+
+ void CheckOffsetPastEndReturnsNoRecords(uint64_t offset_past_end) {
+ Write("foo");
+ Write("bar");
+ Write(BigString("x", 10000));
+ reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
+ uint64 offset = WrittenBytes() + offset_past_end;
+ string record;
+ Status s = reader_->ReadRecord(&offset, &record);
+ ASSERT_TRUE(errors::IsOutOfRange(s)) << s;
+ }
+};
+
+TEST_F(RecordioTest, Empty) { ASSERT_EQ("EOF", Read()); }
+
+TEST_F(RecordioTest, ReadWrite) {
+ Write("foo");
+ Write("bar");
+ Write("");
+ Write("xxxx");
+ ASSERT_EQ("foo", Read());
+ ASSERT_EQ("bar", Read());
+ ASSERT_EQ("", Read());
+ ASSERT_EQ("xxxx", Read());
+ ASSERT_EQ("EOF", Read());
+ ASSERT_EQ("EOF", Read()); // Make sure reads at eof work
+}
+
+TEST_F(RecordioTest, ManyRecords) {
+ for (int i = 0; i < 100000; i++) {
+ Write(NumberString(i));
+ }
+ for (int i = 0; i < 100000; i++) {
+ ASSERT_EQ(NumberString(i), Read());
+ }
+ ASSERT_EQ("EOF", Read());
+}
+
+TEST_F(RecordioTest, RandomRead) {
+ const int N = 500;
+ {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int i = 0; i < N; i++) {
+ Write(RandomSkewedString(i, &rnd));
+ }
+ }
+ {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int i = 0; i < N; i++) {
+ ASSERT_EQ(RandomSkewedString(i, &rnd), Read());
+ }
+ }
+ ASSERT_EQ("EOF", Read());
+}
+
+// Tests of all the error paths in record_reader.cc follow:
+static void AssertHasSubstr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(StringPiece(s).contains(expected)) << s << " does not contain "
+ << expected;
+}
+
+TEST_F(RecordioTest, ReadError) {
+ Write("foo");
+ ForceError();
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptLength) {
+ Write("foo");
+ IncrementByte(6, 100);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptLengthCrc) {
+ Write("foo");
+ IncrementByte(10, 100);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptData) {
+ Write("foo");
+ IncrementByte(14, 10);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, CorruptDataCrc) {
+ Write("foo");
+ IncrementByte(WrittenBytes() - 1, 10);
+ AssertHasSubstr(Read(), "Data loss");
+}
+
+TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); }
+
+TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); }
+
+} // namespace io
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc
new file mode 100644
index 0000000000..769d7e72a5
--- /dev/null
+++ b/tensorflow/core/lib/io/table.cc
@@ -0,0 +1,169 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/table.h"
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/lib/io/two_level_iterator.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace table {
+
+struct Table::Rep {
+ ~Rep() { delete index_block; }
+
+ Options options;
+ Status status;
+ RandomAccessFile* file;
+ // XXX uint64 cache_id;
+
+ BlockHandle metaindex_handle; // Handle to metaindex_block: saved from footer
+ Block* index_block;
+};
+
+Status Table::Open(const Options& options, RandomAccessFile* file,
+ uint64 size, Table** table) {
+ *table = NULL;
+ if (size < Footer::kEncodedLength) {
+ return errors::DataLoss("file is too short to be an sstable");
+ }
+
+ char footer_space[Footer::kEncodedLength];
+ StringPiece footer_input;
+ Status s =
+ file->Read(size - Footer::kEncodedLength, Footer::kEncodedLength,
+ &footer_input, footer_space);
+ if (!s.ok()) return s;
+
+ Footer footer;
+ s = footer.DecodeFrom(&footer_input);
+ if (!s.ok()) return s;
+
+ // Read the index block
+ BlockContents contents;
+ Block* index_block = NULL;
+ if (s.ok()) {
+ s = ReadBlock(file, footer.index_handle(), &contents);
+ if (s.ok()) {
+ index_block = new Block(contents);
+ }
+ }
+
+ if (s.ok()) {
+ // We've successfully read the footer and the index block: we're
+ // ready to serve requests.
+ Rep* rep = new Table::Rep;
+ rep->options = options;
+ rep->file = file;
+ rep->metaindex_handle = footer.metaindex_handle();
+ rep->index_block = index_block;
+ // XXX rep->cache_id = (options.block_cache ?
+ // options.block_cache->NewId() : 0);
+ *table = new Table(rep);
+ } else {
+ if (index_block) delete index_block;
+ }
+
+ return s;
+}
+
+Table::~Table() { delete rep_; }
+
+static void DeleteBlock(void* arg, void* ignored) {
+ delete reinterpret_cast<Block*>(arg);
+}
+
+// Convert an index iterator value (i.e., an encoded BlockHandle)
+// into an iterator over the contents of the corresponding block.
+Iterator* Table::BlockReader(void* arg, const StringPiece& index_value) {
+ Table* table = reinterpret_cast<Table*>(arg);
+ // Cache* block_cache = table->rep_->options.block_cache;
+ Block* block = NULL;
+ // Cache::Handle* cache_handle = NULL;
+
+ BlockHandle handle;
+ StringPiece input = index_value;
+ Status s = handle.DecodeFrom(&input);
+ // We intentionally allow extra stuff in index_value so that we
+ // can add more features in the future.
+
+ if (s.ok()) {
+ BlockContents contents;
+ s = ReadBlock(table->rep_->file, handle, &contents);
+ if (s.ok()) {
+ block = new Block(contents);
+ }
+ }
+
+ Iterator* iter;
+ if (block != NULL) {
+ iter = block->NewIterator();
+ iter->RegisterCleanup(&DeleteBlock, block, NULL);
+ } else {
+ iter = NewErrorIterator(s);
+ }
+ return iter;
+}
+
+Iterator* Table::NewIterator() const {
+ return NewTwoLevelIterator(rep_->index_block->NewIterator(),
+ &Table::BlockReader, const_cast<Table*>(this));
+}
+
+Status Table::InternalGet(const StringPiece& k, void* arg,
+ void (*saver)(void*, const StringPiece&,
+ const StringPiece&)) {
+ Status s;
+ Iterator* iiter = rep_->index_block->NewIterator();
+ iiter->Seek(k);
+ if (iiter->Valid()) {
+ BlockHandle handle;
+ Iterator* block_iter = BlockReader(this, iiter->value());
+ block_iter->Seek(k);
+ if (block_iter->Valid()) {
+ (*saver)(arg, block_iter->key(), block_iter->value());
+ }
+ s = block_iter->status();
+ delete block_iter;
+ }
+ if (s.ok()) {
+ s = iiter->status();
+ }
+ delete iiter;
+ return s;
+}
+
+uint64 Table::ApproximateOffsetOf(const StringPiece& key) const {
+ Iterator* index_iter = rep_->index_block->NewIterator();
+ index_iter->Seek(key);
+ uint64 result;
+ if (index_iter->Valid()) {
+ BlockHandle handle;
+ StringPiece input = index_iter->value();
+ Status s = handle.DecodeFrom(&input);
+ if (s.ok()) {
+ result = handle.offset();
+ } else {
+ // Strange: we can't decode the block handle in the index block.
+ // We'll just return the offset of the metaindex block, which is
+ // close to the whole file size for this case.
+ result = rep_->metaindex_handle.offset();
+ }
+ } else {
+ // key is past the last key in the file. Approximate the offset
+ // by returning the offset of the metaindex block (which is
+ // right near the end of the file).
+ result = rep_->metaindex_handle.offset();
+ }
+ delete index_iter;
+ return result;
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h
new file mode 100644
index 0000000000..230dded2d4
--- /dev/null
+++ b/tensorflow/core/lib/io/table.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_TABLE_H_
+#define TENSORFLOW_LIB_IO_TABLE_H_
+
+#include <stdint.h>
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+class RandomAccessFile;
+
+namespace table {
+
+class Block;
+class BlockHandle;
+class Footer;
+struct Options;
+
+// A Table is a sorted map from strings to strings. Tables are
+// immutable and persistent. A Table may be safely accessed from
+// multiple threads without external synchronization.
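+//
+// Example usage (an illustrative sketch; error handling is elided, and
+// "file" and "file_size" are assumed to come from Env):
+//
+//   Table* table = NULL;
+//   Status s = Table::Open(Options(), file, file_size, &table);
+//   if (s.ok()) {
+//     Iterator* iter = table->NewIterator();
+//     for (iter->SeekToFirst(); iter->Valid(); iter->Next()) {
+//       // ... use iter->key() and iter->value() ...
+//     }
+//     delete iter;
+//     delete table;
+//   }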
+class Table {
+ public:
+ // Attempt to open the table that is stored in bytes [0..file_size)
+ // of "file", and read the metadata entries necessary to allow
+ // retrieving data from the table.
+ //
+ // If successful, returns ok and sets "*table" to the newly opened
+ // table. The client should delete "*table" when no longer needed.
+ // If there was an error while initializing the table, sets "*table"
+ // to NULL and returns a non-ok status. Does not take ownership of
+ // "*file", but the client must ensure that "file" remains live
+ // for the duration of the returned table's lifetime.
+ static Status Open(const Options& options, RandomAccessFile* file,
+ uint64 file_size, Table** table);
+
+ ~Table();
+
+ // Returns a new iterator over the table contents.
+ // The result of NewIterator() is initially invalid (caller must
+ // call one of the Seek methods on the iterator before using it).
+ Iterator* NewIterator() const;
+
+ // Given a key, return an approximate byte offset in the file where
+ // the data for that key begins (or would begin if the key were
+ // present in the file). The returned value is in terms of file
+ // bytes, and so includes effects like compression of the underlying data.
+ // E.g., the approximate offset of the last key in the table will
+ // be close to the file length.
+ uint64 ApproximateOffsetOf(const StringPiece& key) const;
+
+ private:
+ struct Rep;
+ Rep* rep_;
+
+ explicit Table(Rep* rep) { rep_ = rep; }
+ static Iterator* BlockReader(void*, const StringPiece&);
+
+  // Calls (*handle_result)(arg, ...) with the entry found after a call
+  // to Seek(key).  May not make such a call if no entry at or past
+  // "key" is found.
+ Status InternalGet(const StringPiece& key, void* arg,
+ void (*handle_result)(void* arg, const StringPiece& k,
+ const StringPiece& v));
+
+ // No copying allowed
+ Table(const Table&);
+ void operator=(const Table&);
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TABLE_H_
diff --git a/tensorflow/core/lib/io/table_builder.cc b/tensorflow/core/lib/io/table_builder.cc
new file mode 100644
index 0000000000..b786888b30
--- /dev/null
+++ b/tensorflow/core/lib/io/table_builder.cc
@@ -0,0 +1,263 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/table_builder.h"
+
+#include <assert.h>
+#include "tensorflow/core/lib/io/block_builder.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace table {
+
+namespace {
+
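+// Shortens *start to a short string in [*start, limit) when possible.  For
+// example (illustrative), with *start = "abcdef" and limit = "abzzzz" the
+// common prefix is "ab", so *start becomes "abd".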
+void FindShortestSeparator(string* start, const StringPiece& limit) {
+ // Find length of common prefix
+ size_t min_length = std::min(start->size(), limit.size());
+ size_t diff_index = 0;
+ while ((diff_index < min_length) &&
+ ((*start)[diff_index] == limit[diff_index])) {
+ diff_index++;
+ }
+
+ if (diff_index >= min_length) {
+ // Do not shorten if one string is a prefix of the other
+ } else {
+ uint8 diff_byte = static_cast<uint8>((*start)[diff_index]);
+ if (diff_byte < static_cast<uint8>(0xff) &&
+ diff_byte + 1 < static_cast<uint8>(limit[diff_index])) {
+ (*start)[diff_index]++;
+ start->resize(diff_index + 1);
+ assert(StringPiece(*start).compare(limit) < 0);
+ }
+ }
+}
+
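+// Shortens *key to a short string >= *key.  For example (illustrative),
+// "apple" becomes "b"; a key made up entirely of 0xff bytes is left
+// unchanged.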
+void FindShortSuccessor(string* key) {
+ // Find first character that can be incremented
+ size_t n = key->size();
+ for (size_t i = 0; i < n; i++) {
+ const uint8 byte = (*key)[i];
+ if (byte != static_cast<uint8>(0xff)) {
+ (*key)[i] = byte + 1;
+ key->resize(i + 1);
+ return;
+ }
+ }
+ // *key is a run of 0xffs. Leave it alone.
+}
+} // namespace
+
+struct TableBuilder::Rep {
+ Options options;
+ Options index_block_options;
+ WritableFile* file;
+ uint64 offset;
+ Status status;
+ BlockBuilder data_block;
+ BlockBuilder index_block;
+ string last_key;
+ int64 num_entries;
+ bool closed; // Either Finish() or Abandon() has been called.
+
+ // We do not emit the index entry for a block until we have seen the
+ // first key for the next data block. This allows us to use shorter
+ // keys in the index block. For example, consider a block boundary
+ // between the keys "the quick brown fox" and "the who". We can use
+ // "the r" as the key for the index block entry since it is >= all
+ // entries in the first block and < all entries in subsequent
+ // blocks.
+ //
+ // Invariant: r->pending_index_entry is true only if data_block is empty.
+ bool pending_index_entry;
+ BlockHandle pending_handle; // Handle to add to index block
+
+ string compressed_output;
+
+ Rep(const Options& opt, WritableFile* f)
+ : options(opt),
+ index_block_options(opt),
+ file(f),
+ offset(0),
+ data_block(&options),
+ index_block(&index_block_options),
+ num_entries(0),
+ closed(false),
+ pending_index_entry(false) {
+ index_block_options.block_restart_interval = 1;
+ }
+};
+
+TableBuilder::TableBuilder(const Options& options, WritableFile* file)
+ : rep_(new Rep(options, file)) {}
+
+TableBuilder::~TableBuilder() {
+ assert(rep_->closed); // Catch errors where caller forgot to call Finish()
+ delete rep_;
+}
+
+void TableBuilder::Add(const StringPiece& key, const StringPiece& value) {
+ Rep* r = rep_;
+ assert(!r->closed);
+ if (!ok()) return;
+ if (r->num_entries > 0) {
+ assert(key.compare(StringPiece(r->last_key)) > 0);
+ // See if this key+value would make our current block overly large. If
+ // so, emit the current block before adding this key/value
+ const int kOverlyLargeBlockRatio = 2;
+ const size_t this_entry_bytes = key.size() + value.size();
+ if (this_entry_bytes >= kOverlyLargeBlockRatio * r->options.block_size) {
+ Flush();
+ }
+ }
+
+ if (r->pending_index_entry) {
+ assert(r->data_block.empty());
+ FindShortestSeparator(&r->last_key, key);
+ string handle_encoding;
+ r->pending_handle.EncodeTo(&handle_encoding);
+ r->index_block.Add(r->last_key, StringPiece(handle_encoding));
+ r->pending_index_entry = false;
+ }
+
+ r->last_key.assign(key.data(), key.size());
+ r->num_entries++;
+ r->data_block.Add(key, value);
+
+ const size_t estimated_block_size = r->data_block.CurrentSizeEstimate();
+ if (estimated_block_size >= r->options.block_size) {
+ Flush();
+ }
+}
+
+void TableBuilder::Flush() {
+ Rep* r = rep_;
+ assert(!r->closed);
+ if (!ok()) return;
+ if (r->data_block.empty()) return;
+ assert(!r->pending_index_entry);
+ WriteBlock(&r->data_block, &r->pending_handle);
+ if (ok()) {
+ r->pending_index_entry = true;
+ r->status = r->file->Flush();
+ }
+}
+
+void TableBuilder::WriteBlock(BlockBuilder* block, BlockHandle* handle) {
+ // File format contains a sequence of blocks where each block has:
+ // block_data: uint8[n]
+ // type: uint8
+ // crc: uint32
+ assert(ok());
+ Rep* r = rep_;
+ StringPiece raw = block->Finish();
+
+ StringPiece block_contents;
+ CompressionType type = r->options.compression;
+ // TODO(postrelease): Support more compression options: zlib?
+ switch (type) {
+ case kNoCompression:
+ block_contents = raw;
+ break;
+
+ case kSnappyCompression: {
+ string* compressed = &r->compressed_output;
+ if (port::Snappy_Compress(raw.data(), raw.size(), compressed) &&
+ compressed->size() < raw.size() - (raw.size() / 8u)) {
+ block_contents = *compressed;
+ } else {
+ // Snappy not supported, or compressed less than 12.5%, so just
+ // store uncompressed form
+ block_contents = raw;
+ type = kNoCompression;
+ }
+ break;
+ }
+ }
+ WriteRawBlock(block_contents, type, handle);
+ r->compressed_output.clear();
+ block->Reset();
+}
+
+void TableBuilder::WriteRawBlock(const StringPiece& block_contents,
+ CompressionType type, BlockHandle* handle) {
+ Rep* r = rep_;
+ handle->set_offset(r->offset);
+ handle->set_size(block_contents.size());
+ r->status = r->file->Append(block_contents);
+ if (r->status.ok()) {
+ char trailer[kBlockTrailerSize];
+ trailer[0] = type;
+ uint32 crc = crc32c::Value(block_contents.data(), block_contents.size());
+ crc = crc32c::Extend(crc, trailer, 1); // Extend crc to cover block type
+ core::EncodeFixed32(trailer + 1, crc32c::Mask(crc));
+ r->status = r->file->Append(StringPiece(trailer, kBlockTrailerSize));
+ if (r->status.ok()) {
+ r->offset += block_contents.size() + kBlockTrailerSize;
+ }
+ }
+}
+
+Status TableBuilder::status() const { return rep_->status; }
+
+Status TableBuilder::Finish() {
+ Rep* r = rep_;
+ Flush();
+ assert(!r->closed);
+ r->closed = true;
+
+ BlockHandle metaindex_block_handle, index_block_handle;
+
+ // Write metaindex block
+ if (ok()) {
+ BlockBuilder meta_index_block(&r->options);
+ // TODO(postrelease): Add stats and other meta blocks
+ WriteBlock(&meta_index_block, &metaindex_block_handle);
+ }
+
+ // Write index block
+ if (ok()) {
+ if (r->pending_index_entry) {
+ FindShortSuccessor(&r->last_key);
+ string handle_encoding;
+ r->pending_handle.EncodeTo(&handle_encoding);
+ r->index_block.Add(r->last_key, StringPiece(handle_encoding));
+ r->pending_index_entry = false;
+ }
+ WriteBlock(&r->index_block, &index_block_handle);
+ }
+
+ // Write footer
+ if (ok()) {
+ Footer footer;
+ footer.set_metaindex_handle(metaindex_block_handle);
+ footer.set_index_handle(index_block_handle);
+ string footer_encoding;
+ footer.EncodeTo(&footer_encoding);
+ r->status = r->file->Append(footer_encoding);
+ if (r->status.ok()) {
+ r->offset += footer_encoding.size();
+ }
+ }
+ return r->status;
+}
+
+void TableBuilder::Abandon() {
+ Rep* r = rep_;
+ assert(!r->closed);
+ r->closed = true;
+}
+
+uint64 TableBuilder::NumEntries() const { return rep_->num_entries; }
+
+uint64 TableBuilder::FileSize() const { return rep_->offset; }
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h
new file mode 100644
index 0000000000..cebf4d8e0c
--- /dev/null
+++ b/tensorflow/core/lib/io/table_builder.h
@@ -0,0 +1,87 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+//
+// TableBuilder provides the interface used to build a Table
+// (an immutable and sorted map from keys to values).
+//
+// Multiple threads can invoke const methods on a TableBuilder without
+// external synchronization, but if any of the threads may call a
+// non-const method, all threads accessing the same TableBuilder must use
+// external synchronization.
+
+#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+
+#include <stdint.h>
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+class WritableFile;
+namespace table {
+
+class BlockBuilder;
+class BlockHandle;
+
+class TableBuilder {
+ public:
+ // Create a builder that will store the contents of the table it is
+ // building in *file. Does not close the file. It is up to the
+ // caller to close the file after calling Finish().
+ TableBuilder(const Options& options, WritableFile* file);
+
+ // REQUIRES: Either Finish() or Abandon() has been called.
+ ~TableBuilder();
+
+ // Add key,value to the table being constructed.
+ // REQUIRES: key is after any previously added key in lexicographic order.
+ // REQUIRES: Finish(), Abandon() have not been called
+ void Add(const StringPiece& key, const StringPiece& value);
+
+ // Advanced operation: flush any buffered key/value pairs to file.
+ // Can be used to ensure that two adjacent entries never live in
+ // the same data block. Most clients should not need to use this method.
+ // REQUIRES: Finish(), Abandon() have not been called
+ void Flush();
+
+ // Return non-ok iff some error has been detected.
+ Status status() const;
+
+ // Finish building the table. Stops using the file passed to the
+ // constructor after this function returns.
+ // REQUIRES: Finish(), Abandon() have not been called
+ Status Finish();
+
+ // Indicate that the contents of this builder should be abandoned. Stops
+ // using the file passed to the constructor after this function returns.
+ // If the caller is not going to call Finish(), it must call Abandon()
+ // before destroying this builder.
+ // REQUIRES: Finish(), Abandon() have not been called
+ void Abandon();
+
+ // Number of calls to Add() so far.
+ uint64 NumEntries() const;
+
+ // Size of the file generated so far. If invoked after a successful
+ // Finish() call, returns the size of the final generated file.
+ uint64 FileSize() const;
+
+ private:
+ bool ok() const { return status().ok(); }
+ void WriteBlock(BlockBuilder* block, BlockHandle* handle);
+ void WriteRawBlock(const StringPiece& data, CompressionType,
+ BlockHandle* handle);
+
+ struct Rep;
+ Rep* rep_;
+
+ // No copying allowed
+ TableBuilder(const TableBuilder&);
+ void operator=(const TableBuilder&);
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
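
A minimal usage sketch of the TableBuilder API declared above (the file handle, key/value data, and block size are illustrative assumptions, not part of this change):

#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/public/env.h"

// Writes a tiny two-entry table to "file"; keys must be added in sorted order.
tensorflow::Status WriteSmallTable(tensorflow::WritableFile* file) {
  tensorflow::table::Options options;
  options.block_size = 4096;  // smaller than the 256KB default, for illustration
  tensorflow::table::TableBuilder builder(options, file);
  builder.Add("apple", "red");
  builder.Add("banana", "yellow");
  tensorflow::Status s = builder.Finish();  // writes metaindex, index and footer
  if (s.ok()) s = file->Flush();            // the builder never closes the file
  return s;
}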
diff --git a/tensorflow/core/lib/io/table_format.txt b/tensorflow/core/lib/io/table_format.txt
new file mode 100644
index 0000000000..7edb9fb121
--- /dev/null
+++ b/tensorflow/core/lib/io/table_format.txt
@@ -0,0 +1,8 @@
+File format
+===========
+
+The table format is heavily based on the table format for the LevelDB
+open source key/value store, with the exception that our tables
+do not support "filter" meta blocks (Bloom Filters). See:
+
+https://code.google.com/p/leveldb/source/browse/doc/table_format.txt
diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h
new file mode 100644
index 0000000000..45b061b03b
--- /dev/null
+++ b/tensorflow/core/lib/io/table_options.h
@@ -0,0 +1,53 @@
+#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+
+#include <stddef.h>
+
+namespace tensorflow {
+namespace table {
+
+// DB contents are stored in a set of blocks, each of which holds a
+// sequence of key,value pairs. Each block may be compressed before
+// being stored in a file. The following enum describes which
+// compression method (if any) is used to compress a block.
+enum CompressionType {
+ // NOTE: do not change the values of existing entries, as these are
+ // part of the persistent format on disk.
+ kNoCompression = 0x0,
+ kSnappyCompression = 0x1
+};
+
+// Options to control the behavior of a table (passed to Table::Open)
+struct Options {
+ // Approximate size of user data packed per block. Note that the
+ // block size specified here corresponds to uncompressed data. The
+ // actual size of the unit read from disk may be smaller if
+ // compression is enabled. This parameter can be changed dynamically.
+ size_t block_size = 262144;
+
+ // Number of keys between restart points for delta encoding of keys.
+ // This parameter can be changed dynamically. Most clients should
+ // leave this parameter alone.
+ int block_restart_interval = 16;
+
+ // Compress blocks using the specified compression algorithm. This
+ // parameter can be changed dynamically.
+ //
+ // Default: kSnappyCompression, which gives lightweight but fast
+ // compression.
+ //
+ // Typical speeds of kSnappyCompression on an Intel(R) Core(TM)2 2.4GHz:
+ // ~200-500MB/s compression
+ // ~400-800MB/s decompression
+ // Note that these speeds are significantly faster than most
+ // persistent storage speeds, and therefore it is typically never
+ // worth switching to kNoCompression. Even if the input data is
+ // incompressible, the kSnappyCompression implementation will
+ // efficiently detect that and will switch to uncompressed mode.
+ CompressionType compression = kSnappyCompression;
+};
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
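
A small illustrative sketch of overriding the defaults above (the specific values are arbitrary, not recommendations):

#include "tensorflow/core/lib/io/table_options.h"

tensorflow::table::Options MakeTableOptions() {
  tensorflow::table::Options options;   // 256KB blocks and Snappy by default
  options.block_size = 64 * 1024;       // smaller blocks: finer-grained reads, larger index
  options.block_restart_interval = 8;   // denser restart points within each block
  options.compression = tensorflow::table::kNoCompression;  // e.g. for pre-compressed values
  return options;
}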
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
new file mode 100644
index 0000000000..66e90ac64e
--- /dev/null
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -0,0 +1,601 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/table.h"
+
+#include <map>
+#include <string>
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/block_builder.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/iterator.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace table {
+
+namespace test {
+static StringPiece RandomString(random::SimplePhilox* rnd, int len,
+ string* dst) {
+ dst->resize(len);
+ for (int i = 0; i < len; i++) {
+ (*dst)[i] = static_cast<char>(' ' + rnd->Uniform(95)); // ' ' .. '~'
+ }
+ return StringPiece(*dst);
+}
+static string RandomKey(random::SimplePhilox* rnd, int len) {
+ // Make sure to generate a wide variety of characters so we
+ // test the boundary conditions for short-key optimizations.
+ static const char kTestChars[] = {'\0', '\1', 'a', 'b', 'c',
+ 'd', 'e', '\xfd', '\xfe', '\xff'};
+ string result;
+ for (int i = 0; i < len; i++) {
+ result += kTestChars[rnd->Uniform(sizeof(kTestChars))];
+ }
+ return result;
+}
+static StringPiece CompressibleString(random::SimplePhilox* rnd,
+ double compressed_fraction, size_t len,
+ string* dst) {
+ int raw = static_cast<int>(len * compressed_fraction);
+ if (raw < 1) raw = 1;
+ string raw_data;
+ RandomString(rnd, raw, &raw_data);
+
+ // Duplicate the random data until we have filled "len" bytes
+ dst->clear();
+ while (dst->size() < len) {
+ dst->append(raw_data);
+ }
+ dst->resize(len);
+ return StringPiece(*dst);
+}
+}  // namespace test
+
+static void Increment(string* key) { key->push_back('\0'); }
+
+// An STL comparator that compares two StringPieces
+namespace {
+struct STLLessThan {
+ STLLessThan() {}
+ bool operator()(const string& a, const string& b) const {
+ return StringPiece(a).compare(StringPiece(b)) < 0;
+ }
+};
+} // namespace
+
+class StringSink : public WritableFile {
+ public:
+ ~StringSink() {}
+
+ const string& contents() const { return contents_; }
+
+ virtual Status Close() { return Status::OK(); }
+ virtual Status Flush() { return Status::OK(); }
+ virtual Status Sync() { return Status::OK(); }
+
+ virtual Status Append(const StringPiece& data) {
+ contents_.append(data.data(), data.size());
+ return Status::OK();
+ }
+
+ private:
+ string contents_;
+};
+
+class StringSource : public RandomAccessFile {
+ public:
+ StringSource(const StringPiece& contents)
+ : contents_(contents.data(), contents.size()), bytes_read_(0) {}
+
+ virtual ~StringSource() {}
+
+ uint64 Size() const { return contents_.size(); }
+
+ virtual Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const {
+ if (offset > contents_.size()) {
+ return errors::InvalidArgument("invalid Read offset");
+ }
+ if (offset + n > contents_.size()) {
+ n = contents_.size() - offset;
+ }
+ memcpy(scratch, &contents_[offset], n);
+ *result = StringPiece(scratch, n);
+ bytes_read_ += n;
+ return Status::OK();
+ }
+
+ uint64 BytesRead() const { return bytes_read_; }
+
+ private:
+ string contents_;
+ mutable uint64 bytes_read_;
+};
+
+typedef std::map<string, string, STLLessThan> KVMap;
+
+// Helper class for tests to unify the interface between
+// BlockBuilder/TableBuilder and Block/Table.
+class Constructor {
+ public:
+ explicit Constructor() : data_(STLLessThan()) {}
+ virtual ~Constructor() {}
+
+ void Add(const string& key, const StringPiece& value) {
+ data_[key] = value.ToString();
+ }
+
+ // Finish constructing the data structure with all the keys that have
+ // been added so far. Returns the keys in sorted order in "*keys"
+ // and stores the key/value pairs in "*kvmap"
+ void Finish(const Options& options, std::vector<string>* keys, KVMap* kvmap) {
+ *kvmap = data_;
+ keys->clear();
+ for (KVMap::const_iterator it = data_.begin(); it != data_.end(); ++it) {
+ keys->push_back(it->first);
+ }
+ data_.clear();
+ Status s = FinishImpl(options, *kvmap);
+ ASSERT_TRUE(s.ok()) << s.ToString();
+ }
+
+ // Construct the data structure from the data in "data"
+ virtual Status FinishImpl(const Options& options, const KVMap& data) = 0;
+
+ virtual Iterator* NewIterator() const = 0;
+
+ virtual const KVMap& data() { return data_; }
+
+ private:
+ KVMap data_;
+};
+
+class BlockConstructor : public Constructor {
+ public:
+ BlockConstructor() : block_(NULL) {}
+ ~BlockConstructor() { delete block_; }
+ virtual Status FinishImpl(const Options& options, const KVMap& data) {
+ delete block_;
+ block_ = NULL;
+ BlockBuilder builder(&options);
+
+ for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) {
+ builder.Add(it->first, it->second);
+ }
+ // Open the block
+ data_ = builder.Finish().ToString();
+ BlockContents contents;
+ contents.data = data_;
+ contents.cachable = false;
+ contents.heap_allocated = false;
+ block_ = new Block(contents);
+ return Status::OK();
+ }
+ virtual Iterator* NewIterator() const { return block_->NewIterator(); }
+
+ private:
+ string data_;
+ Block* block_;
+};
+
+class TableConstructor : public Constructor {
+ public:
+ TableConstructor() : source_(NULL), table_(NULL) {}
+ ~TableConstructor() { Reset(); }
+ virtual Status FinishImpl(const Options& options, const KVMap& data) {
+ Reset();
+ StringSink sink;
+ TableBuilder builder(options, &sink);
+
+ for (KVMap::const_iterator it = data.begin(); it != data.end(); ++it) {
+ builder.Add(it->first, it->second);
+ TF_CHECK_OK(builder.status());
+ }
+ Status s = builder.Finish();
+ TF_CHECK_OK(s) << s.ToString();
+
+ CHECK_EQ(sink.contents().size(), builder.FileSize());
+
+ // Open the table
+ source_ = new StringSource(sink.contents());
+ Options table_options;
+ return Table::Open(table_options, source_, sink.contents().size(), &table_);
+ }
+
+ virtual Iterator* NewIterator() const { return table_->NewIterator(); }
+
+ uint64 ApproximateOffsetOf(const StringPiece& key) const {
+ return table_->ApproximateOffsetOf(key);
+ }
+
+ uint64 BytesRead() const { return source_->BytesRead(); }
+
+ private:
+ void Reset() {
+ delete table_;
+ delete source_;
+ table_ = NULL;
+ source_ = NULL;
+ }
+
+ StringSource* source_;
+ Table* table_;
+};
+
+enum TestType { TABLE_TEST, BLOCK_TEST };
+
+struct TestArgs {
+ TestType type;
+ int restart_interval;
+};
+
+static const TestArgs kTestArgList[] = {
+ {TABLE_TEST, 16}, {TABLE_TEST, 1}, {TABLE_TEST, 1024},
+ {BLOCK_TEST, 16}, {BLOCK_TEST, 1}, {BLOCK_TEST, 1024},
+};
+static const int kNumTestArgs = sizeof(kTestArgList) / sizeof(kTestArgList[0]);
+
+class Harness : public ::testing::Test {
+ public:
+ Harness() : constructor_(NULL) {}
+
+ void Init(const TestArgs& args) {
+ delete constructor_;
+ constructor_ = NULL;
+ options_ = Options();
+
+ options_.block_restart_interval = args.restart_interval;
+ // Use shorter block size for tests to exercise block boundary
+ // conditions more.
+ options_.block_size = 256;
+ switch (args.type) {
+ case TABLE_TEST:
+ constructor_ = new TableConstructor();
+ break;
+ case BLOCK_TEST:
+ constructor_ = new BlockConstructor();
+ break;
+ }
+ }
+
+ ~Harness() { delete constructor_; }
+
+ void Add(const string& key, const string& value) {
+ constructor_->Add(key, value);
+ }
+
+ void Test(random::SimplePhilox* rnd) {
+ std::vector<string> keys;
+ KVMap data;
+ constructor_->Finish(options_, &keys, &data);
+
+ TestForwardScan(keys, data);
+ TestRandomAccess(rnd, keys, data);
+ }
+
+ void TestForwardScan(const std::vector<string>& keys, const KVMap& data) {
+ Iterator* iter = constructor_->NewIterator();
+ ASSERT_TRUE(!iter->Valid());
+ iter->SeekToFirst();
+ for (KVMap::const_iterator model_iter = data.begin();
+ model_iter != data.end(); ++model_iter) {
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ iter->Next();
+ }
+ ASSERT_TRUE(!iter->Valid());
+ delete iter;
+ }
+
+ void TestRandomAccess(random::SimplePhilox* rnd,
+ const std::vector<string>& keys, const KVMap& data) {
+ static const bool kVerbose = false;
+ Iterator* iter = constructor_->NewIterator();
+ ASSERT_TRUE(!iter->Valid());
+ KVMap::const_iterator model_iter = data.begin();
+ if (kVerbose) fprintf(stderr, "---\n");
+ for (int i = 0; i < 200; i++) {
+ const int toss = rnd->Uniform(3);
+ switch (toss) {
+ case 0: {
+ if (iter->Valid()) {
+ if (kVerbose) fprintf(stderr, "Next\n");
+ iter->Next();
+ ++model_iter;
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ }
+ break;
+ }
+
+ case 1: {
+ if (kVerbose) fprintf(stderr, "SeekToFirst\n");
+ iter->SeekToFirst();
+ model_iter = data.begin();
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ break;
+ }
+
+ case 2: {
+ string key = PickRandomKey(rnd, keys);
+ model_iter = data.lower_bound(key);
+ if (kVerbose)
+ fprintf(stderr, "Seek '%s'\n", str_util::CEscape(key).c_str());
+ iter->Seek(StringPiece(key));
+ ASSERT_EQ(ToString(data, model_iter), ToString(iter));
+ break;
+ }
+ }
+ }
+ delete iter;
+ }
+
+ string ToString(const KVMap& data, const KVMap::const_iterator& it) {
+ if (it == data.end()) {
+ return "END";
+ } else {
+ return "'" + it->first + "->" + it->second + "'";
+ }
+ }
+
+ string ToString(const KVMap& data, const KVMap::const_reverse_iterator& it) {
+ if (it == data.rend()) {
+ return "END";
+ } else {
+ return "'" + it->first + "->" + it->second + "'";
+ }
+ }
+
+ string ToString(const Iterator* it) {
+ if (!it->Valid()) {
+ return "END";
+ } else {
+ return "'" + it->key().ToString() + "->" + it->value().ToString() + "'";
+ }
+ }
+
+ string PickRandomKey(random::SimplePhilox* rnd,
+ const std::vector<string>& keys) {
+ if (keys.empty()) {
+ return "foo";
+ } else {
+ const int index = rnd->Uniform(keys.size());
+ string result = keys[index];
+ switch (rnd->Uniform(3)) {
+ case 0:
+ // Return an existing key
+ break;
+ case 1: {
+ // Attempt to return something smaller than an existing key
+ if (result.size() > 0 && result[result.size() - 1] > '\0') {
+ result[result.size() - 1]--;
+ }
+ break;
+ }
+ case 2: {
+ // Return something larger than an existing key
+ Increment(&result);
+ break;
+ }
+ }
+ return result;
+ }
+ }
+
+ private:
+ Options options_;
+ Constructor* constructor_;
+};
+
+// Test empty table/block.
+TEST_F(Harness, Empty) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ random::SimplePhilox rnd(&philox);
+ Test(&rnd);
+ }
+}
+
+// Special test for a block with no restart entries. The C++ leveldb
+// code never generates such blocks, but the Java version of leveldb
+// seems to.
+TEST_F(Harness, ZeroRestartPointsInBlock) {
+ char data[sizeof(uint32)];
+ memset(data, 0, sizeof(data));
+ BlockContents contents;
+ contents.data = StringPiece(data, sizeof(data));
+ contents.cachable = false;
+ contents.heap_allocated = false;
+ Block block(contents);
+ Iterator* iter = block.NewIterator();
+ iter->SeekToFirst();
+ ASSERT_TRUE(!iter->Valid());
+ iter->Seek("foo");
+ ASSERT_TRUE(!iter->Valid());
+ delete iter;
+}
+
+// Test the empty key
+TEST_F(Harness, SimpleEmptyKey) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("", "v");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleSingle) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 2, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("abc", "v");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleMulti) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 3, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("abc", "v");
+ Add("abcd", "v");
+ Add("ac", "v2");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleMultiBigValues) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 3, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("ainitial", "tiny");
+ Add("anext", string(10000000, 'a'));
+ Add("anext2", string(10000000, 'b'));
+ Add("azz", "tiny");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, SimpleSpecialKey) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 4, 17);
+ random::SimplePhilox rnd(&philox);
+ Add("\xff\xff", "v3");
+ Test(&rnd);
+ }
+}
+
+TEST_F(Harness, Randomized) {
+ for (int i = 0; i < kNumTestArgs; i++) {
+ Init(kTestArgList[i]);
+ random::PhiloxRandom philox(testing::RandomSeed() + 5, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int num_entries = 0; num_entries < 2000;
+ num_entries += (num_entries < 50 ? 1 : 200)) {
+ if ((num_entries % 10) == 0) {
+ fprintf(stderr, "case %d of %d: num_entries = %d\n", (i + 1),
+ int(kNumTestArgs), num_entries);
+ }
+ for (int e = 0; e < num_entries; e++) {
+ string v;
+ Add(test::RandomKey(&rnd, rnd.Skewed(4)),
+ test::RandomString(&rnd, rnd.Skewed(5), &v).ToString());
+ }
+ Test(&rnd);
+ }
+ }
+}
+
+static bool Between(uint64 val, uint64 low, uint64 high) {
+ bool result = (val >= low) && (val <= high);
+ if (!result) {
+ fprintf(stderr, "Value %llu is not in range [%llu, %llu]\n",
+ (unsigned long long)(val), (unsigned long long)(low),
+ (unsigned long long)(high));
+ }
+ return result;
+}
+
+class TableTest {};
+
+TEST(TableTest, ApproximateOffsetOfPlain) {
+ TableConstructor c;
+ c.Add("k01", "hello");
+ c.Add("k02", "hello2");
+ c.Add("k03", string(10000, 'x'));
+ c.Add("k04", string(200000, 'x'));
+ c.Add("k05", string(300000, 'x'));
+ c.Add("k06", "hello3");
+ c.Add("k07", string(100000, 'x'));
+ std::vector<string> keys;
+ KVMap kvmap;
+ Options options;
+ options.block_size = 1024;
+ options.compression = kNoCompression;
+ c.Finish(options, &keys, &kvmap);
+
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01a"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 10, 500));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 10000, 11000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04a"), 210000, 211000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k05"), 210000, 211000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k06"), 510000, 511000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k07"), 510000, 511000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 610000, 612000));
+}
+
+static bool SnappyCompressionSupported() {
+ string out;
+ StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
+ return port::Snappy_Compress(in.data(), in.size(), &out);
+}
+
+TEST(TableTest, ApproximateOffsetOfCompressed) {
+ if (!SnappyCompressionSupported()) {
+ fprintf(stderr, "skipping compression tests\n");
+ return;
+ }
+
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ TableConstructor c;
+ string tmp;
+ c.Add("k01", "hello");
+ c.Add("k02", test::CompressibleString(&rnd, 0.25, 10000, &tmp));
+ c.Add("k03", "hello3");
+ c.Add("k04", test::CompressibleString(&rnd, 0.25, 10000, &tmp));
+ std::vector<string> keys;
+ KVMap kvmap;
+ Options options;
+ options.block_size = 1024;
+ options.compression = kSnappyCompression;
+ c.Finish(options, &keys, &kvmap);
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("abc"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k01"), 0, 0));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k02"), 10, 100));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k03"), 2000, 3000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("k04"), 2000, 3000));
+ ASSERT_TRUE(Between(c.ApproximateOffsetOf("xyz"), 4000, 6000));
+}
+
+TEST(TableTest, SeekToFirstKeyDoesNotReadTooMuch) {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string tmp;
+ TableConstructor c;
+ c.Add("k01", "firstvalue");
+ c.Add("k03", test::CompressibleString(&rnd, 0.25, 1000000, &tmp));
+ c.Add("k04", "abc");
+ std::vector<string> keys;
+ KVMap kvmap;
+ Options options;
+ options.block_size = 1024;
+ options.compression = kNoCompression;
+ c.Finish(options, &keys, &kvmap);
+
+ Iterator* iter = c.NewIterator();
+ iter->Seek("k01");
+ delete iter;
+ // Make sure we don't read the big second block when just trying to
+ // retrieve the data in the first key
+ EXPECT_LT(c.BytesRead(), 200);
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/two_level_iterator.cc b/tensorflow/core/lib/io/two_level_iterator.cc
new file mode 100644
index 0000000000..409baade6d
--- /dev/null
+++ b/tensorflow/core/lib/io/two_level_iterator.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#include "tensorflow/core/lib/io/two_level_iterator.h"
+
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/lib/io/block.h"
+#include "tensorflow/core/lib/io/format.h"
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+namespace {
+
+typedef Iterator* (*BlockFunction)(void*, const StringPiece&);
+
+class TwoLevelIterator : public Iterator {
+ public:
+ TwoLevelIterator(Iterator* index_iter, BlockFunction block_function,
+ void* arg);
+
+ virtual ~TwoLevelIterator();
+
+ virtual void Seek(const StringPiece& target);
+ virtual void SeekToFirst();
+ virtual void Next();
+
+ virtual bool Valid() const {
+ return (data_iter_ == nullptr) ? false : data_iter_->Valid();
+ }
+ virtual StringPiece key() const {
+ assert(Valid());
+ return data_iter_->key();
+ }
+ virtual StringPiece value() const {
+ assert(Valid());
+ return data_iter_->value();
+ }
+ virtual Status status() const {
+ // It'd be nice if status() returned a const Status& instead of a
+ // Status
+ if (!index_iter_->status().ok()) {
+ return index_iter_->status();
+ } else if (data_iter_ != NULL && !data_iter_->status().ok()) {
+ return data_iter_->status();
+ } else {
+ return status_;
+ }
+ }
+
+ private:
+ void SaveError(const Status& s) {
+ if (status_.ok() && !s.ok()) status_ = s;
+ }
+ void SkipEmptyDataBlocksForward();
+ void SetDataIterator(Iterator* data_iter);
+ void InitDataBlock();
+
+ BlockFunction block_function_;
+ void* arg_;
+ Status status_;
+ Iterator* index_iter_;
+ Iterator* data_iter_; // May be NULL
+ // If data_iter_ is non-NULL, then "data_block_handle_" holds the
+ // "index_value" passed to block_function_ to create the data_iter_.
+ string data_block_handle_;
+};
+
+TwoLevelIterator::TwoLevelIterator(Iterator* index_iter,
+ BlockFunction block_function, void* arg)
+ : block_function_(block_function),
+ arg_(arg),
+ index_iter_(index_iter),
+ data_iter_(NULL) {}
+
+TwoLevelIterator::~TwoLevelIterator() {
+ delete index_iter_;
+ delete data_iter_;
+}
+
+void TwoLevelIterator::Seek(const StringPiece& target) {
+ index_iter_->Seek(target);
+ InitDataBlock();
+ if (data_iter_ != NULL) data_iter_->Seek(target);
+ SkipEmptyDataBlocksForward();
+}
+
+void TwoLevelIterator::SeekToFirst() {
+ index_iter_->SeekToFirst();
+ InitDataBlock();
+ if (data_iter_ != NULL) data_iter_->SeekToFirst();
+ SkipEmptyDataBlocksForward();
+}
+
+void TwoLevelIterator::Next() {
+ assert(Valid());
+ data_iter_->Next();
+ SkipEmptyDataBlocksForward();
+}
+
+void TwoLevelIterator::SkipEmptyDataBlocksForward() {
+ while (data_iter_ == NULL || !data_iter_->Valid()) {
+ // Move to next block
+ if (!index_iter_->Valid()) {
+ SetDataIterator(NULL);
+ return;
+ }
+ index_iter_->Next();
+ InitDataBlock();
+ if (data_iter_ != NULL) data_iter_->SeekToFirst();
+ }
+}
+
+void TwoLevelIterator::SetDataIterator(Iterator* data_iter) {
+ if (data_iter_ != NULL) {
+ SaveError(data_iter_->status());
+ delete data_iter_;
+ }
+ data_iter_ = data_iter;
+}
+
+void TwoLevelIterator::InitDataBlock() {
+ if (!index_iter_->Valid()) {
+ SetDataIterator(NULL);
+ } else {
+ StringPiece handle = index_iter_->value();
+ if (data_iter_ != NULL && handle.compare(data_block_handle_) == 0) {
+ // data_iter_ is already constructed with this iterator, so
+ // no need to change anything
+ } else {
+ Iterator* iter = (*block_function_)(arg_, handle);
+ data_block_handle_.assign(handle.data(), handle.size());
+ SetDataIterator(iter);
+ }
+ }
+}
+
+} // namespace
+
+Iterator* NewTwoLevelIterator(Iterator* index_iter,
+ BlockFunction block_function, void* arg) {
+ return new TwoLevelIterator(index_iter, block_function, arg);
+}
+
+} // namespace table
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h
new file mode 100644
index 0000000000..1cc5d2f921
--- /dev/null
+++ b/tensorflow/core/lib/io/two_level_iterator.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file. See the AUTHORS file for names of contributors.
+
+#ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
+#define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
+
+#include "tensorflow/core/lib/io/iterator.h"
+
+namespace tensorflow {
+namespace table {
+
+// Return a new two level iterator. A two-level iterator contains an
+// index iterator whose values point to a sequence of blocks where
+// each block is itself a sequence of key,value pairs. The returned
+// two-level iterator yields the concatenation of all key/value pairs
+// in the sequence of blocks. Takes ownership of "index_iter" and
+// will delete it when no longer needed.
+//
+// Uses a supplied function to convert an index_iter value into
+// an iterator over the contents of the corresponding block.
+extern Iterator* NewTwoLevelIterator(
+ Iterator* index_iter,
+ Iterator* (*block_function)(void* arg, const StringPiece& index_value),
+ void* arg);
+
+} // namespace table
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
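
A hedged sketch of how a block_function is typically supplied. BlockSource, its OpenBlock method, and the scan below are placeholders assumed for illustration only:

#include "tensorflow/core/lib/io/iterator.h"
#include "tensorflow/core/lib/io/two_level_iterator.h"

// Placeholder for any object that can turn an index value (an encoded block
// handle) into an iterator over that block's contents.
struct BlockSource {
  tensorflow::table::Iterator* OpenBlock(const tensorflow::StringPiece& handle);
};

static tensorflow::table::Iterator* OpenBlockThunk(
    void* arg, const tensorflow::StringPiece& index_value) {
  return static_cast<BlockSource*>(arg)->OpenBlock(index_value);
}

// Walks every key/value pair of every block, in order. Ownership of
// index_iter passes to the two-level iterator and is released with it.
void ScanAll(tensorflow::table::Iterator* index_iter, BlockSource* source) {
  tensorflow::table::Iterator* it =
      tensorflow::table::NewTwoLevelIterator(index_iter, &OpenBlockThunk, source);
  for (it->SeekToFirst(); it->Valid(); it->Next()) {
    // Use it->key() / it->value() here.
  }
  delete it;
}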
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.cc b/tensorflow/core/lib/jpeg/jpeg_handle.cc
new file mode 100644
index 0000000000..4521be0afb
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.cc
@@ -0,0 +1,162 @@
+// This file implements a memory destination for libjpeg.
+// The design is very similar to jdatadst.c in libjpeg.
+// These functions are not meant to be used directly; see jpeg_mem.h instead.
+// We fill out the stubs required by jpeglib; those stubs are private to the
+// implementation. We only make JPGMemSrc and JPGMemDest available.
+
+#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
+
+#include <setjmp.h>
+#include <stddef.h>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+void CatchError(j_common_ptr cinfo) {
+ (*cinfo->err->output_message)(cinfo);
+ jmp_buf *jpeg_jmpbuf = reinterpret_cast<jmp_buf *>(cinfo->client_data);
+ jpeg_destroy(cinfo);
+ longjmp(*jpeg_jmpbuf, 1);
+}
+
+// *****************************************************************************
+// *****************************************************************************
+// *****************************************************************************
+// Destination functions
+
+// -----------------------------------------------------------------------------
+void MemInitDestination(j_compress_ptr cinfo) {
+ MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ VLOG(1) << "Initializing buffer=" << dest->bufsize << " bytes";
+ dest->pub.next_output_byte = dest->buffer;
+ dest->pub.free_in_buffer = dest->bufsize;
+ dest->datacount = 0;
+ if (dest->dest) {
+ dest->dest->clear();
+ }
+}
+
+// -----------------------------------------------------------------------------
+boolean MemEmptyOutputBuffer(j_compress_ptr cinfo) {
+ MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ VLOG(1) << "Writing " << dest->bufsize << " bytes";
+ if (dest->dest) {
+ dest->dest->append(reinterpret_cast<char *>(dest->buffer), dest->bufsize);
+ }
+ dest->pub.next_output_byte = dest->buffer;
+ dest->pub.free_in_buffer = dest->bufsize;
+ return TRUE;
+}
+
+// -----------------------------------------------------------------------------
+void MemTermDestination(j_compress_ptr cinfo) {
+ MemDestMgr *dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ VLOG(1) << "Writing " << dest->bufsize - dest->pub.free_in_buffer << " bytes";
+ if (dest->dest) {
+ dest->dest->append(reinterpret_cast<char *>(dest->buffer),
+ dest->bufsize - dest->pub.free_in_buffer);
+ VLOG(1) << "Total size= " << dest->dest->size();
+ }
+ dest->datacount = dest->bufsize - dest->pub.free_in_buffer;
+}
+
+// -----------------------------------------------------------------------------
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize) {
+ SetDest(cinfo, buffer, bufsize, NULL);
+}
+
+// -----------------------------------------------------------------------------
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
+ string *destination) {
+ MemDestMgr *dest;
+ if (cinfo->dest == NULL) {
+ cinfo->dest = reinterpret_cast<struct jpeg_destination_mgr *>(
+ (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo),
+ JPOOL_PERMANENT, sizeof(MemDestMgr)));
+ }
+
+ dest = reinterpret_cast<MemDestMgr *>(cinfo->dest);
+ dest->bufsize = bufsize;
+ dest->buffer = static_cast<JOCTET *>(buffer);
+ dest->dest = destination;
+ dest->pub.init_destination = MemInitDestination;
+ dest->pub.empty_output_buffer = MemEmptyOutputBuffer;
+ dest->pub.term_destination = MemTermDestination;
+}
+
+// *****************************************************************************
+// *****************************************************************************
+// *****************************************************************************
+// Source functions
+
+// -----------------------------------------------------------------------------
+void MemInitSource(j_decompress_ptr cinfo) {
+ MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ src->pub.next_input_byte = src->data;
+ src->pub.bytes_in_buffer = src->datasize;
+}
+
+// -----------------------------------------------------------------------------
+// We emulate the same error-handling as fill_input_buffer() from jdatasrc.c,
+// for coherency's sake.
+boolean MemFillInputBuffer(j_decompress_ptr cinfo) {
+ static const JOCTET kEOIBuffer[2] = {0xff, JPEG_EOI};
+ MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ if (src->pub.bytes_in_buffer == 0 && src->pub.next_input_byte == src->data) {
+ // empty file -> treated as an error.
+ ERREXIT(cinfo, JERR_INPUT_EMPTY);
+ return FALSE;
+ } else if (src->pub.bytes_in_buffer) {
+ // if there's still some data left, it's probably corrupted
+ return src->try_recover_truncated_jpeg ? TRUE : FALSE;
+ } else if (src->pub.next_input_byte != kEOIBuffer &&
+ src->try_recover_truncated_jpeg) {
+ // In an attempt to recover truncated files, we insert a fake EOI
+ WARNMS(cinfo, JWRN_JPEG_EOF);
+ src->pub.next_input_byte = kEOIBuffer;
+ src->pub.bytes_in_buffer = 2;
+ return TRUE;
+ } else {
+ // We already inserted a fake EOI and it wasn't enough, so this time
+ // it's really an error.
+ ERREXIT(cinfo, JERR_FILE_READ);
+ return FALSE;
+ }
+}
+
+// -----------------------------------------------------------------------------
+void MemTermSource(j_decompress_ptr cinfo) {}
+
+// -----------------------------------------------------------------------------
+void MemSkipInputData(j_decompress_ptr cinfo, long jump) {
+ MemSourceMgr *src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ src->pub.bytes_in_buffer -= jump;
+ src->pub.next_input_byte += jump;
+}
+
+// -----------------------------------------------------------------------------
+void SetSrc(j_decompress_ptr cinfo, const void *data,
+ unsigned long int datasize, bool try_recover_truncated_jpeg) {
+ MemSourceMgr *src;
+
+ cinfo->src = reinterpret_cast<struct jpeg_source_mgr *>(
+ (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo),
+ JPOOL_PERMANENT, sizeof(MemSourceMgr)));
+
+ src = reinterpret_cast<MemSourceMgr *>(cinfo->src);
+ src->pub.init_source = MemInitSource;
+ src->pub.fill_input_buffer = MemFillInputBuffer;
+ src->pub.skip_input_data = MemSkipInputData;
+ src->pub.resync_to_restart = jpeg_resync_to_restart;
+ src->pub.term_source = MemTermSource;
+ src->data = reinterpret_cast<const unsigned char *>(data);
+ src->datasize = datasize;
+ src->pub.bytes_in_buffer = 0;
+ src->pub.next_input_byte = NULL;
+ src->try_recover_truncated_jpeg = try_recover_truncated_jpeg;
+}
+
+} // namespace jpeg
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h
new file mode 100644
index 0000000000..58f7f6f666
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.h
@@ -0,0 +1,51 @@
+// This file declares the functions and structures for memory I/O with libjpeg.
+// These functions are not meant to be used directly; see jpeg_mem.h instead.
+
+#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+
+extern "C" {
+#include "external/jpeg_archive/jpeg-9a/jinclude.h"
+#include "external/jpeg_archive/jpeg-9a/jpeglib.h"
+#include "external/jpeg_archive/jpeg-9a/jerror.h"
+#include "external/jpeg_archive/jpeg-9a/transupp.h" // for rotations
+}
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+// Handler for fatal JPEG library errors: clean up & return
+void CatchError(j_common_ptr cinfo);
+
+typedef struct {
+ struct jpeg_destination_mgr pub;
+ JOCTET *buffer;
+ int bufsize;
+ int datacount;
+ string *dest;
+} MemDestMgr;
+
+typedef struct {
+ struct jpeg_source_mgr pub;
+ const unsigned char *data;
+ unsigned long int datasize;
+ bool try_recover_truncated_jpeg;
+} MemSourceMgr;
+
+void SetSrc(j_decompress_ptr cinfo, const void *data,
+ unsigned long int datasize, bool try_recover_truncated_jpeg);
+
+// JPEG destination: we will store all the data in a buffer "buffer" of total
+// size "bufsize", if the buffer overflows, we will be in trouble.
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize);
+// Same as above, except that buffer is only used as a temporary structure and
+// is emptied into "destination" as soon as it fills up.
+void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
+ string *destination);
+
+} // namespace jpeg
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
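
A minimal sketch of wiring a memory buffer into a standard libjpeg read cycle with SetSrc. Error handling via setjmp/CatchError is omitted here; GetImageInfo in jpeg_mem.cc shows the full pattern:

#include "tensorflow/core/lib/jpeg/jpeg_handle.h"

// Reads only the header of an in-memory JPEG and reports its dimensions.
bool PeekImageSize(const void* data, unsigned long size, int* width, int* height) {
  struct jpeg_decompress_struct cinfo;
  struct jpeg_error_mgr jerr;
  cinfo.err = jpeg_std_error(&jerr);
  jpeg_create_decompress(&cinfo);
  tensorflow::jpeg::SetSrc(&cinfo, data, size,
                           /*try_recover_truncated_jpeg=*/false);
  jpeg_read_header(&cinfo, TRUE);
  *width = cinfo.image_width;
  *height = cinfo.image_height;
  jpeg_destroy_decompress(&cinfo);
  return true;
}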
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
new file mode 100644
index 0000000000..556f13e388
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -0,0 +1,557 @@
+// This file defines functions to compress and uncompress JPEG data
+// to and from memory, as well as some direct manipulations of JPEG strings.
+
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+
+#include <setjmp.h>
+#include <string.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+// -----------------------------------------------------------------------------
+// Decompression
+
+namespace {
+
+enum JPEGErrors {
+ JPEGERRORS_OK,
+ JPEGERRORS_UNEXPECTED_END_OF_DATA,
+ JPEGERRORS_BAD_PARAM
+};
+
+// Prevent bad compiler behavior in ASAN mode by wrapping most of the
+// arguments in a single struct.
+class FewerArgsForCompiler {
+ public:
+ FewerArgsForCompiler(int datasize, const UncompressFlags& flags, int* nwarn,
+ std::function<uint8*(int, int, int)> allocate_output)
+ : datasize_(datasize),
+ flags_(flags),
+ pnwarn_(nwarn),
+ allocate_output_(allocate_output),
+ fraction_read_(0.),
+ height_(0),
+ stride_(0) {
+ if (pnwarn_ != nullptr) *pnwarn_ = 0;
+ }
+
+ const int datasize_;
+ const UncompressFlags flags_;
+ int* const pnwarn_;
+ std::function<uint8*(int, int, int)> allocate_output_;
+  float fraction_read_;  // fraction of scanlines successfully read
+ int height_;
+ int stride_;
+};
+
+uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
+ // unpack the argball
+ const int datasize = argball->datasize_;
+ const auto& flags = argball->flags_;
+ const int ratio = flags.ratio;
+ int components = flags.components;
+ int stride = flags.stride; // may be 0
+ int* const nwarn = argball->pnwarn_; // may be NULL
+
+ // can't decode if the ratio is not recognized by libjpeg
+ if ((ratio != 1) && (ratio != 2) && (ratio != 4) && (ratio != 8)) {
+ return nullptr;
+ }
+
+ // if empty image, return
+ if (datasize == 0 || srcdata == NULL) return nullptr;
+
+ // Declare temporary buffer pointer here so that we can free on error paths
+ JSAMPLE* tempdata = nullptr;
+
+ // Initialize libjpeg structures to have a memory source
+ // Modify the usual jpeg error manager to catch fatal errors.
+ JPEGErrors error = JPEGERRORS_OK;
+ struct jpeg_decompress_struct cinfo;
+ struct jpeg_error_mgr jerr;
+ cinfo.err = jpeg_std_error(&jerr);
+ jmp_buf jpeg_jmpbuf;
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ return nullptr;
+ }
+
+ jpeg_create_decompress(&cinfo);
+ SetSrc(&cinfo, srcdata, datasize, flags.try_recover_truncated_jpeg);
+ jpeg_read_header(&cinfo, TRUE);
+
+ // Set components automatically if desired
+ if (components == 0) components = cinfo.num_components;
+
+ // set grayscale and ratio parameters
+ switch (components) {
+ case 1:
+ cinfo.out_color_space = JCS_GRAYSCALE;
+ break;
+ case 3:
+ case 4:
+ if (cinfo.jpeg_color_space == JCS_CMYK ||
+ cinfo.jpeg_color_space == JCS_YCCK) {
+        // Always use CMYK for output in a 4-channel JPEG; libjpeg has a
+        // built-in decoder for it.
+ cinfo.out_color_space = JCS_CMYK;
+ } else {
+ cinfo.out_color_space = JCS_RGB;
+ }
+ break;
+ default:
+ LOG(ERROR) << " Invalid components value " << components << std::endl;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+ cinfo.do_fancy_upsampling = boolean(flags.fancy_upscaling);
+ cinfo.scale_num = 1;
+ cinfo.scale_denom = ratio;
+ // Activating this has a quality/speed trade-off implication:
+ // cinfo.dct_method = JDCT_IFAST;
+
+ jpeg_start_decompress(&cinfo);
+
+ // check for compatible stride
+ const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE);
+ if (stride == 0) {
+ stride = min_stride;
+ } else if (stride < min_stride) {
+ LOG(ERROR) << "Incompatible stride: " << stride << " < " << min_stride;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+
+ // Remember stride and height for use in Uncompress
+ argball->height_ = cinfo.output_height;
+ argball->stride_ = stride;
+
+ uint8* const dstdata = argball->allocate_output_(
+ cinfo.output_width, cinfo.output_height, components);
+ if (dstdata == nullptr) {
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+ JSAMPLE* output_line = static_cast<JSAMPLE*>(dstdata);
+
+ // Temporary buffer used for CMYK -> RGB conversion.
+ const bool use_cmyk = (cinfo.out_color_space == JCS_CMYK);
+ tempdata = use_cmyk ? new JSAMPLE[cinfo.output_width * 4] : NULL;
+
+ // If there is an error reading a line, this aborts the reading.
+ // Save the fraction of the image that has been read.
+ argball->fraction_read_ = 1.0;
+ while (cinfo.output_scanline < cinfo.output_height) {
+ int num_lines_read = 0;
+ if (cinfo.out_color_space == JCS_CMYK) {
+ num_lines_read = jpeg_read_scanlines(&cinfo, &tempdata, 1);
+ // Convert CMYK to RGB
+ for (size_t i = 0; i < cinfo.output_width; ++i) {
+ int c = tempdata[4 * i + 0];
+ int m = tempdata[4 * i + 1];
+ int y = tempdata[4 * i + 2];
+ int k = tempdata[4 * i + 3];
+ int r, g, b;
+ if (cinfo.saw_Adobe_marker) {
+ r = (k * c) / 255;
+ g = (k * m) / 255;
+ b = (k * y) / 255;
+ } else {
+ r = (255 - k) * (255 - c) / 255;
+ g = (255 - k) * (255 - m) / 255;
+ b = (255 - k) * (255 - y) / 255;
+ }
+ output_line[3 * i + 0] = r;
+ output_line[3 * i + 1] = g;
+ output_line[3 * i + 2] = b;
+ }
+ } else {
+ num_lines_read = jpeg_read_scanlines(&cinfo, &output_line, 1);
+ }
+ // Handle error cases
+ if (num_lines_read == 0) {
+ LOG(ERROR) << "Premature end of JPEG data. Stopped at line "
+ << cinfo.output_scanline << "/" << cinfo.output_height;
+ if (!flags.try_recover_truncated_jpeg) {
+ argball->fraction_read_ =
+ static_cast<float>(cinfo.output_scanline) / cinfo.output_height;
+ error = JPEGERRORS_UNEXPECTED_END_OF_DATA;
+ } else {
+ for (size_t line = cinfo.output_scanline; line < cinfo.output_height;
+ ++line) {
+ if (line == 0) {
+ // If even the first line is missing, fill with black color
+ memset(output_line, 0, min_stride);
+ } else {
+ // else, just replicate the line above.
+ memcpy(output_line, output_line - stride, min_stride);
+ }
+ output_line += stride;
+ }
+ argball->fraction_read_ = 1.0; // consider all lines as read
+ // prevent error-on-exit in libjpeg:
+ cinfo.output_scanline = cinfo.output_height;
+ }
+ break;
+ }
+ DCHECK_EQ(num_lines_read, 1);
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_line, min_stride);
+ output_line += stride;
+ }
+ delete[] tempdata;
+
+ // Convert the RGB data to RGBA, with alpha set to 0xFF to indicate
+ // opacity.
+ // RGBRGBRGB... --> RGBARGBARGBA...
+ if (components == 4) {
+ // Start on the last line.
+ JSAMPLE* scanlineptr =
+ static_cast<JSAMPLE*>(dstdata + (cinfo.output_height - 1) * stride);
+ const JSAMPLE kOpaque = -1; // All ones appropriate for JSAMPLE.
+ const int right_rgb = (cinfo.output_width - 1) * 3;
+ const int right_rgba = (cinfo.output_width - 1) * 4;
+
+ for (int y = cinfo.output_height; y-- > 0;) {
+ // We do all the transformations in place, going backwards for each row.
+ const JSAMPLE* rgb_pixel = scanlineptr + right_rgb;
+ JSAMPLE* rgba_pixel = scanlineptr + right_rgba;
+ scanlineptr -= stride;
+ for (int x = cinfo.output_width; x-- > 0;
+ rgba_pixel -= 4, rgb_pixel -= 3) {
+ // We copy the 3 bytes at rgb_pixel into the 4 bytes at rgba_pixel
+ // The "a" channel is set to be opaque.
+ rgba_pixel[3] = kOpaque;
+ rgba_pixel[2] = rgb_pixel[2];
+ rgba_pixel[1] = rgb_pixel[1];
+ rgba_pixel[0] = rgb_pixel[0];
+ }
+ }
+ }
+
+ switch (components) {
+ case 1:
+ if (cinfo.output_components != 1) {
+ error = JPEGERRORS_BAD_PARAM;
+ }
+ break;
+ case 3:
+ case 4:
+ if (cinfo.out_color_space == JCS_CMYK) {
+ if (cinfo.output_components != 4) {
+ error = JPEGERRORS_BAD_PARAM;
+ }
+ } else {
+ if (cinfo.output_components != 3) {
+ error = JPEGERRORS_BAD_PARAM;
+ }
+ }
+ break;
+ default:
+      // will never happen; should be caught by the previous switch
+ LOG(ERROR) << "Invalid components value " << components << std::endl;
+ jpeg_destroy_decompress(&cinfo);
+ return nullptr;
+ }
+
+ // save number of warnings if requested
+ if (nwarn != nullptr) {
+ *nwarn = cinfo.err->num_warnings;
+ }
+
+ // Handle errors in JPEG
+ switch (error) {
+ case JPEGERRORS_OK:
+ jpeg_finish_decompress(&cinfo);
+ break;
+ case JPEGERRORS_UNEXPECTED_END_OF_DATA:
+ case JPEGERRORS_BAD_PARAM:
+ jpeg_abort(reinterpret_cast<j_common_ptr>(&cinfo));
+ break;
+ default:
+ LOG(ERROR) << "Unhandled case " << error;
+ break;
+ }
+ jpeg_destroy_decompress(&cinfo);
+
+ return dstdata;
+}
+
+} // anonymous namespace
+
+// -----------------------------------------------------------------------------
+// We do the apparently silly thing of packing 5 of the arguments
+// into a structure that is then passed to another routine
+// that does all the work. The reason is that we want to catch
+// fatal JPEG library errors with setjmp/longjmp, and g++ and
+// associated libraries aren't good enough to guarantee that 7
+// parameters won't get clobbered by the longjmp. So we help
+// it out a little.
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* nwarn,
+ std::function<uint8*(int, int, int)> allocate_output) {
+ FewerArgsForCompiler argball(datasize, flags, nwarn, allocate_output);
+ uint8* const dstdata = UncompressLow(srcdata, &argball);
+ const float fraction_read = argball.fraction_read_;
+ if (dstdata == NULL ||
+ fraction_read < std::min(1.0f, flags.min_acceptable_fraction)) {
+ // Major failure, none or too-partial read returned; get out
+ return NULL;
+ }
+
+ // If there was an error in reading the jpeg data,
+ // set the unread pixels to black
+ if (fraction_read < 1.0) {
+ const int first_bad_line =
+ static_cast<int>(fraction_read * argball.height_);
+ uint8* start = dstdata + first_bad_line * argball.stride_;
+ const int nbytes = (argball.height_ - first_bad_line) * argball.stride_;
+ memset(static_cast<void*>(start), 0, nbytes);
+ }
+
+ return dstdata;
+}
+
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* pwidth, int* pheight,
+ int* pcomponents, int* nwarn) {
+ uint8* buffer = NULL;
+ uint8* result =
+ Uncompress(srcdata, datasize, flags, nwarn,
+ [=, &buffer](int width, int height, int components) {
+ if (pwidth != nullptr) *pwidth = width;
+ if (pheight != nullptr) *pheight = height;
+ if (pcomponents != nullptr) *pcomponents = components;
+ buffer = new uint8[height * width * components];
+ return buffer;
+ });
+ if (!result) delete[] buffer;
+ return result;
+}
+
+// ----------------------------------------------------------------------------
+// Computes image information from jpeg header.
+// Returns true on success; false on failure.
+bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height,
+ int* components) {
+ // Init in case of failure
+ if (width) *width = 0;
+ if (height) *height = 0;
+ if (components) *components = 0;
+
+ // If empty image, return
+ if (datasize == 0 || srcdata == NULL) return false;
+
+ // Initialize libjpeg structures to have a memory source
+ // Modify the usual jpeg error manager to catch fatal errors.
+ struct jpeg_decompress_struct cinfo;
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf;
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ return false;
+ }
+
+ // set up, read header, set image parameters, save size
+ jpeg_create_decompress(&cinfo);
+ SetSrc(&cinfo, srcdata, datasize, false);
+
+ jpeg_read_header(&cinfo, TRUE);
+ jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo
+ if (width) *width = cinfo.output_width;
+ if (height) *height = cinfo.output_height;
+ if (components) *components = cinfo.output_components;
+
+ jpeg_destroy_decompress(&cinfo);
+
+ return true;
+}
+
+// -----------------------------------------------------------------------------
+// Compression
+
+namespace {
+bool CompressInternal(const uint8* srcdata, int width, int height,
+ const CompressFlags& flags, string* output) {
+ output->clear();
+ const int components = (static_cast<int>(flags.format) & 0xff);
+ int in_stride = flags.stride;
+ if (in_stride == 0) {
+ in_stride = width * (static_cast<int>(flags.format) & 0xff);
+ } else if (in_stride < width * components) {
+ LOG(ERROR) << "Incompatible input stride";
+ return false;
+ }
+
+ JOCTET* buffer = 0;
+
+ // NOTE: for broader use xmp_metadata should be made a unicode string
+ CHECK(srcdata != nullptr);
+ CHECK(output != nullptr);
+ // This struct contains the JPEG compression parameters and pointers to
+ // working space
+ struct jpeg_compress_struct cinfo;
+ // This struct represents a JPEG error handler.
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf; // recovery point in case of error
+
+ // Step 1: allocate and initialize JPEG compression object
+ // Use the usual jpeg error manager.
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) {
+ output->clear();
+ delete[] buffer;
+ return false;
+ }
+
+ jpeg_create_compress(&cinfo);
+
+ // Step 2: specify data destination
+ // We allocate a buffer of reasonable size. If we have a small image, just
+ // estimate the size of the output using the number of bytes of the input.
+ // If this is getting too big, we will append to the string by chunks of 1MB.
+ // This seems like a reasonable compromise between performance and memory.
+ int bufsize = std::min(width * height * components, 1 << 20);
+ buffer = new JOCTET[bufsize];
+ SetDest(&cinfo, buffer, bufsize, output);
+
+ // Step 3: set parameters for compression
+ cinfo.image_width = width;
+ cinfo.image_height = height;
+ switch (components) {
+ case 1:
+ cinfo.input_components = 1;
+ cinfo.in_color_space = JCS_GRAYSCALE;
+ break;
+ case 3:
+ case 4:
+ cinfo.input_components = 3;
+ cinfo.in_color_space = JCS_RGB;
+ break;
+ default:
+ LOG(ERROR) << " Invalid components value " << components << std::endl;
+ output->clear();
+ delete[] buffer;
+ return false;
+ }
+ jpeg_set_defaults(&cinfo);
+ if (flags.optimize_jpeg_size) cinfo.optimize_coding = TRUE;
+
+ cinfo.density_unit = flags.density_unit; // JFIF code for pixel size units:
+ // 1 = in, 2 = cm
+ cinfo.X_density = flags.x_density; // Horizontal pixel density
+ cinfo.Y_density = flags.y_density; // Vertical pixel density
+ jpeg_set_quality(&cinfo, flags.quality, TRUE);
+
+ if (flags.progressive) {
+ jpeg_simple_progression(&cinfo);
+ }
+
+ if (!flags.chroma_downsampling) {
+ // Turn off chroma subsampling (it is on by default). For more details on
+ // chroma subsampling, see http://en.wikipedia.org/wiki/Chroma_subsampling.
+ for (int i = 0; i < cinfo.num_components; ++i) {
+ cinfo.comp_info[i].h_samp_factor = 1;
+ cinfo.comp_info[i].v_samp_factor = 1;
+ }
+ }
+
+ jpeg_start_compress(&cinfo, TRUE);
+
+ // Embed XMP metadata if any
+ if (!flags.xmp_metadata.empty()) {
+ // XMP metadata is embedded in the APP1 tag of JPEG and requires this
+ // namespace header string (null-terminated)
+ const string name_space = "http://ns.adobe.com/xap/1.0/";
+ const int name_space_length = name_space.size();
+ const int metadata_length = flags.xmp_metadata.size();
+ const int packet_length = metadata_length + name_space_length + 1;
+ std::unique_ptr<JOCTET[]> joctet_packet(new JOCTET[packet_length]);
+
+ for (int i = 0; i < name_space_length; i++) {
+ // Conversion char --> JOCTET
+ joctet_packet[i] = name_space[i];
+ }
+ joctet_packet[name_space_length] = 0; // null-terminate namespace string
+
+ for (int i = 0; i < metadata_length; i++) {
+ // Conversion char --> JOCTET
+ joctet_packet[i + name_space_length + 1] = flags.xmp_metadata[i];
+ }
+ jpeg_write_marker(&cinfo, JPEG_APP0 + 1, joctet_packet.get(),
+ packet_length);
+ }
+
+ // JSAMPLEs per row in image_buffer
+ std::unique_ptr<JSAMPLE[]> row_temp(
+ new JSAMPLE[width * cinfo.input_components]);
+ while (cinfo.next_scanline < cinfo.image_height) {
+ JSAMPROW row_pointer[1]; // pointer to JSAMPLE row[s]
+ const uint8* r = &srcdata[cinfo.next_scanline * in_stride];
+ uint8* p = static_cast<uint8*>(row_temp.get());
+ switch (flags.format) {
+ case FORMAT_RGBA: {
+ for (int i = 0; i < width; ++i, p += 3, r += 4) {
+ p[0] = r[0];
+ p[1] = r[1];
+ p[2] = r[2];
+ }
+ row_pointer[0] = row_temp.get();
+ break;
+ }
+ case FORMAT_ABGR: {
+ for (int i = 0; i < width; ++i, p += 3, r += 4) {
+ p[0] = r[3];
+ p[1] = r[2];
+ p[2] = r[1];
+ }
+ row_pointer[0] = row_temp.get();
+ break;
+ }
+ default: {
+ row_pointer[0] = reinterpret_cast<JSAMPLE*>(const_cast<JSAMPLE*>(r));
+ }
+ }
+ CHECK_EQ(jpeg_write_scanlines(&cinfo, row_pointer, 1), 1);
+ }
+ jpeg_finish_compress(&cinfo);
+
+ // release JPEG compression object
+ jpeg_destroy_compress(&cinfo);
+ delete[] buffer;
+ return true;
+}
+
+} // anonymous namespace
+
+// -----------------------------------------------------------------------------
+
+bool Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags, string* output) {
+ return CompressInternal(static_cast<const uint8*>(srcdata), width, height,
+ flags, output);
+}
+
+string Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags) {
+ string temp;
+ CompressInternal(static_cast<const uint8*>(srcdata), width, height, flags,
+ &temp);
+ // If CompressInternal fails, temp will be empty.
+ return temp;
+}
+
+} // namespace jpeg
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h
new file mode 100644
index 0000000000..19ba7d4acf
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.h
@@ -0,0 +1,130 @@
+// This file defines functions to compress and uncompress JPEG files
+// to and from memory. It provides interfaces for raw images
+// (data array and size fields).
+// Direct manipulations of JPEG strings (Flip, Rotate, Crop, ...) are supplied.
+
+#ifndef TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+#define TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace jpeg {
+
+// Flags for Uncompress
+struct UncompressFlags {
+  // ratio can be 1, 2, 4, or 8 and represents the denominator for the scaling
+  // factor (e.g. ratio = 4 means that the resulting image will be at 1/4 of the
+  // original size in both directions).
+ int ratio = 1;
+
+ // The number of bytes per pixel (1, 3 or 4), or 0 for autodetect.
+ int components = 0;
+
+ // If true, decoder will use a slower but nicer upscaling of the chroma
+ // planes (yuv420/422 only).
+ bool fancy_upscaling = true;
+
+ // If true, will attempt to fill in missing lines of truncated files
+ bool try_recover_truncated_jpeg = false;
+
+ // The minimum required fraction of lines read before the image is accepted.
+ float min_acceptable_fraction = 1.0;
+
+  // The distance in bytes from one scanline to the next. Should be at least
+ // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride
+ // used will be this minimal value.
+ int stride = 0;
+};
+
+// Uncompress some raw JPEG data given by the pointer srcdata and the length
+// datasize.
+// - width and height are the addresses where the size of the uncompressed
+//   image in pixels is stored. May be nullptr.
+// - components is the address where the number of read components is
+//   stored. This is *output only*: to request a specific number of
+//   components, use flags.components. May be nullptr.
+// - nwarn is the address in which to store the number of warnings.
+// May be nullptr.
+// The function returns a pointer to the raw uncompressed data or NULL if
+// there was an error. The caller of the function is responsible for
+// freeing the memory (using delete []).
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* width, int* height,
+ int* components, // Output only: useful with autodetect
+ int* nwarn);
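+
+// Sample usage (a sketch; the buffer name jpeg_data and the flag values are
+// illustrative only):
+//
+// UncompressFlags flags;
+// flags.components = 3; // request RGB output
+// int width, height, components;
+// uint8* pixels = Uncompress(jpeg_data.data(), jpeg_data.size(), flags,
+// &width, &height, &components, nullptr);
+// if (pixels != nullptr) {
+// // ... consume width * height * 3 bytes of RGB data ...
+// delete[] pixels;
+// }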
+
+// Version of Uncompress that allocates memory via a callback. The callback
+// arguments are (width, height, components). If the size is known ahead of
+// time, the callback can simply return an existing buffer; passing a callback
+// allows the buffer to be shaped based on the JPEG header. The caller is
+// responsible for freeing the memory *even along error paths*.
+uint8* Uncompress(const void* srcdata, int datasize,
+ const UncompressFlags& flags, int* nwarn,
+ std::function<uint8*(int, int, int)> allocate_output);
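+
+// Sample usage of the callback variant (a sketch; the std::vector-backed
+// buffer is illustrative only):
+//
+// UncompressFlags flags;
+// flags.components = 3;
+// std::vector<uint8> pixels;
+// uint8* result = Uncompress(
+// jpeg_data.data(), jpeg_data.size(), flags, nullptr /* nwarn */,
+// [&pixels](int w, int h, int c) {
+// pixels.resize(w * h * c);
+// return pixels.data();
+// });
+// // On success, result == pixels.data(); on error, result is nullptr.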
+
+// Read jpeg header and get image information. Returns true on success.
+// The width, height, and components pointers may be null.
+bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height,
+ int* components);
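+
+// Sample usage (a sketch; jpeg_data is assumed to hold the encoded bytes):
+//
+// int width, height, components;
+// if (GetImageInfo(jpeg_data.data(), jpeg_data.size(),
+// &width, &height, &components)) {
+// // dimensions are known without decoding the pixel data
+// }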
+
+// Note: (format & 0xff) = number of components (<=> bytes per pixel)
+enum Format {
+ FORMAT_GRAYSCALE = 0x001, // 1 byte/pixel
+ FORMAT_RGB = 0x003, // 3 bytes/pixel RGBRGBRGBRGB...
+ FORMAT_RGBA = 0x004, // 4 bytes/pixel RGBARGBARGBARGBA...
+ FORMAT_ABGR = 0x104 // 4 bytes/pixel ABGRABGRABGR...
+};
+
+// Flags for compression
+struct CompressFlags {
+ // Encoding of the input data for compression
+ Format format;
+
+ // Quality of the compression from 0-100
+ int quality = 95;
+
+ // If true, create a jpeg image that loads progressively
+ bool progressive = false;
+
+ // If true, reduce jpeg size without changing quality (at the cost of CPU/RAM)
+ bool optimize_jpeg_size = false;
+
+ // See http://en.wikipedia.org/wiki/Chroma_subsampling
+ bool chroma_downsampling = true;
+
+ // Resolution
+ int density_unit = 1; // 1 = in, 2 = cm
+ int x_density = 300;
+ int y_density = 300;
+
+ // If not empty, embed this XMP metadata in the image header
+ StringPiece xmp_metadata;
+
+ // The distance in bytes from one scanline to the other. Should be at least
+ // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride
+ // used will be this minimal value.
+ int stride = 0;
+};
+
+// Compress some raw image given in srcdata. The data is a 2D array of size
+// stride*height with one of the formats enumerated above.
+// The encoded data is returned as a string.
+// If flags.xmp_metadata is not empty, it is embedded in the image header.
+// On error, returns the empty string (which is never a valid jpeg).
+string Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags);
+
+// On error, returns false and sets output to empty.
+bool Compress(const void* srcdata, int width, int height,
+ const CompressFlags& flags, string* output);
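+
+// Sample usage (a sketch; rgb_pixels is assumed to hold width * height * 3
+// bytes of interleaved RGB data):
+//
+// CompressFlags flags;
+// flags.format = FORMAT_RGB;
+// flags.quality = 90;
+// string jpeg = Compress(rgb_pixels, width, height, flags);
+// if (jpeg.empty()) { /* compression failed */ }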
+
+} // namespace jpeg
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
new file mode 100644
index 0000000000..23e72f9d57
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
@@ -0,0 +1,304 @@
+#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
+
+#include <setjmp.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include <memory>
+
+#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include <gtest/gtest.h>
+
+#include "tensorflow/core/lib/core/casts.h"
+
+namespace tensorflow {
+namespace jpeg {
+namespace {
+
+const char kTestData[] = "tensorflow/core/lib/jpeg/testdata/";
+
+int ComputeSumAbsoluteDifference(const uint8* a, const uint8* b, int width,
+ int height, int a_stride, int b_stride) {
+ int totalerr = 0;
+ for (int i = 0; i < height; i++) {
+ const uint8* const pa = a + i * a_stride;
+ const uint8* const pb = b + i * b_stride;
+ for (int j = 0; j < 3 * width; j++) {
+ totalerr += abs(static_cast<int>(pa[j]) - static_cast<int>(pb[j]));
+ }
+ }
+ return totalerr;
+}
+
+// Reads the contents of the file into output
+void ReadFileToStringOrDie(Env* env, const string& filename, string* output) {
+ TF_CHECK_OK(ReadFileToString(env, filename, output));
+}
+
+void TestJPEG(Env* env, const string& jpegfile) {
+ // Read the data from the jpeg file into memory
+ string jpeg;
+ ReadFileToStringOrDie(env, jpegfile, &jpeg);
+ const int fsize = jpeg.size();
+ const uint8* const temp = bit_cast<const uint8*>(jpeg.data());
+
+ // try partial decoding (half of the data)
+ int w, h, c;
+ std::unique_ptr<uint8[]> imgdata;
+
+ UncompressFlags flags;
+ flags.components = 3;
+
+ // set min_acceptable_fraction to something insufficient
+ flags.min_acceptable_fraction = 0.8;
+ imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() == NULL);
+
+ // now use a value for which fsize/2 is enough; missing lines are filled with black
+ flags.min_acceptable_fraction = 0.01;
+ imgdata.reset(Uncompress(temp, fsize / 2, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() != NULL);
+
+ // finally, uncompress the whole data
+ flags.min_acceptable_fraction = 1.0;
+ imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() != NULL);
+
+ // Uncompress the data to RGBA, too
+ flags.min_acceptable_fraction = 1.0;
+ flags.components = 4;
+ imgdata.reset(Uncompress(temp, fsize, flags, &w, &h, &c, NULL));
+ CHECK(imgdata.get() != NULL);
+}
+
+TEST(JpegMemTest, Jpeg) {
+ Env* env = Env::Default();
+ const string data_path = kTestData;
+
+ // Name of a valid jpeg file on the disk
+ TestJPEG(env, data_path + "jpeg_merge_test1.jpg");
+
+ // Exercise CMYK machinery as well
+ TestJPEG(env, data_path + "jpeg_merge_test1_cmyk.jpg");
+}
+
+TEST(JpegMemTest, Jpeg2) {
+ // create known data, for size in_w x in_h
+ const int in_w = 256;
+ const int in_h = 256;
+ const int stride1 = 3 * in_w;
+ const std::unique_ptr<uint8[]> refdata1(new uint8[stride1 * in_h]);
+ for (int i = 0; i < in_h; i++) {
+ for (int j = 0; j < in_w; j++) {
+ const int offset = i * stride1 + 3 * j;
+ refdata1[offset + 0] = i;
+ refdata1[offset + 1] = j;
+ refdata1[offset + 2] = static_cast<uint8>((i + j) >> 1);
+ }
+ }
+
+ // duplicate with weird input stride
+ const int stride2 = 3 * 357;
+ const std::unique_ptr<uint8[]> refdata2(new uint8[stride2 * in_h]);
+ for (int i = 0; i < in_h; i++) {
+ memcpy(&refdata2[i * stride2], &refdata1[i * stride1], 3 * in_w);
+ }
+
+ // Test compression
+ string cpdata1, cpdata2;
+ {
+ const string kXMP = "XMP_TEST_123";
+
+ // Compress it to JPEG
+ CompressFlags flags;
+ flags.format = FORMAT_RGB;
+ flags.quality = 97;
+ flags.xmp_metadata = kXMP;
+ cpdata1 = Compress(refdata1.get(), in_w, in_h, flags);
+ flags.stride = stride2;
+ cpdata2 = Compress(refdata2.get(), in_w, in_h, flags);
+ // Different input stride shouldn't change the output
+ CHECK_EQ(cpdata1, cpdata2);
+
+ // Verify valid XMP.
+ CHECK_NE(string::npos, cpdata1.find(kXMP));
+
+ // Test the other API, where a storage string is supplied
+ string cptest;
+ flags.stride = 0;
+ Compress(refdata1.get(), in_w, in_h, flags, &cptest);
+ CHECK_EQ(cptest, cpdata1);
+ flags.stride = stride2;
+ Compress(refdata2.get(), in_w, in_h, flags, &cptest);
+ CHECK_EQ(cptest, cpdata2);
+ }
+
+ // Uncompress twice: once with 3 components and once with autodetect
+ std::unique_ptr<uint8[]> imgdata1;
+ for (const int components : {0, 3}) {
+ // Uncompress it
+ UncompressFlags flags;
+ flags.components = components;
+ int w, h, c;
+ imgdata1.reset(
+ Uncompress(cpdata1.c_str(), cpdata1.length(), flags, &w, &h, &c, NULL));
+
+ // Check obvious formatting stuff
+ CHECK_EQ(w, in_w);
+ CHECK_EQ(h, in_h);
+ CHECK_EQ(c, 3);
+ CHECK(imgdata1.get());
+
+ // Compare the two images
+ const int totalerr = ComputeSumAbsoluteDifference(
+ imgdata1.get(), refdata1.get(), in_w, in_h, stride1, stride1);
+ CHECK_LE(totalerr, 85000);
+ }
+
+ // check the second image too. Should be bitwise identical to the first.
+ // uncompress using a weird stride
+ {
+ UncompressFlags flags;
+ flags.stride = 3 * 411;
+ const std::unique_ptr<uint8[]> imgdata2(new uint8[flags.stride * in_h]);
+ CHECK(imgdata2.get() == Uncompress(cpdata2.c_str(), cpdata2.length(), flags,
+ NULL, [&imgdata2](int w, int h, int c) {
+ CHECK_EQ(w, in_w);
+ CHECK_EQ(h, in_h);
+ CHECK_EQ(c, 3);
+ return imgdata2.get();
+ }));
+ const int totalerr = ComputeSumAbsoluteDifference(
+ imgdata1.get(), imgdata2.get(), in_w, in_h, stride1, flags.stride);
+ CHECK_EQ(totalerr, 0);
+ }
+}
+
+// Takes JPEG data and reads its headers to determine whether or not the JPEG
+// was chroma downsampled.
+bool IsChromaDownsampled(const string& jpegdata) {
+ // Initialize libjpeg structures to have a memory source
+ // Modify the usual jpeg error manager to catch fatal errors.
+ struct jpeg_decompress_struct cinfo;
+ struct jpeg_error_mgr jerr;
+ jmp_buf jpeg_jmpbuf;
+ cinfo.err = jpeg_std_error(&jerr);
+ cinfo.client_data = &jpeg_jmpbuf;
+ jerr.error_exit = CatchError;
+ if (setjmp(jpeg_jmpbuf)) return false;
+
+ // set up, read header, set image parameters, save size
+ jpeg_create_decompress(&cinfo);
+ SetSrc(&cinfo, jpegdata.c_str(), jpegdata.size(), false);
+
+ jpeg_read_header(&cinfo, TRUE);
+ jpeg_start_decompress(&cinfo); // required to transfer image size to cinfo
+ const int components = cinfo.output_components;
+ if (components == 1) return false;
+
+ // Check validity
+ CHECK_EQ(3, components);
+ CHECK_EQ(cinfo.comp_info[1].h_samp_factor, cinfo.comp_info[2].h_samp_factor)
+ << "The h sampling factors should be the same.";
+ CHECK_EQ(cinfo.comp_info[1].v_samp_factor, cinfo.comp_info[2].v_samp_factor)
+ << "The v sampling factors should be the same.";
+ for (int i = 0; i < components; ++i) {
+ CHECK_GT(cinfo.comp_info[i].h_samp_factor, 0) << "Invalid sampling factor.";
+ CHECK_EQ(cinfo.comp_info[i].h_samp_factor, cinfo.comp_info[i].v_samp_factor)
+ << "The sampling factor should be the same in both directions.";
+ }
+
+ // We're downsampled if we use fewer samples for color than for brightness.
+ // Do this before deallocating cinfo.
+ const bool downsampled =
+ cinfo.comp_info[1].h_samp_factor < cinfo.comp_info[0].h_samp_factor;
+
+ jpeg_destroy_decompress(&cinfo);
+ return downsampled;
+}
+
+TEST(JpegMemTest, ChromaDownsampling) {
+ // Read the data from a test jpeg file into memory
+ const string jpegfile = string(kTestData) + "jpeg_merge_test1.jpg";
+ string jpeg;
+ ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg);
+
+ // Verify that compressing the JPEG with chroma downsampling works.
+ //
+ // First, uncompress the JPEG.
+ UncompressFlags unflags;
+ unflags.components = 3;
+ int w, h, c, num_warnings;
+ std::unique_ptr<uint8[]> uncompressed(Uncompress(
+ jpeg.c_str(), jpeg.size(), unflags, &w, &h, &c, &num_warnings));
+ CHECK(uncompressed.get() != NULL);
+ CHECK_EQ(num_warnings, 0);
+
+ // Recompress the JPEG with and without chroma downsampling
+ for (const bool downsample : {false, true}) {
+ CompressFlags flags;
+ flags.format = FORMAT_RGB;
+ flags.quality = 85;
+ flags.chroma_downsampling = downsample;
+ string recompressed;
+ Compress(uncompressed.get(), w, h, flags, &recompressed);
+ CHECK(!recompressed.empty());
+ CHECK_EQ(IsChromaDownsampled(recompressed), downsample);
+ }
+}
+
+void TestBadJPEG(Env* env, const string& bad_jpeg_file, int expected_width,
+ int expected_height, const string& reference_RGB_file,
+ const bool try_recover_truncated_jpeg) {
+ string jpeg;
+ ReadFileToStringOrDie(env, bad_jpeg_file, &jpeg);
+
+ UncompressFlags flags;
+ flags.components = 3;
+ flags.try_recover_truncated_jpeg = try_recover_truncated_jpeg;
+
+ int width, height, components;
+ std::unique_ptr<uint8[]> imgdata;
+ imgdata.reset(Uncompress(jpeg.c_str(), jpeg.size(), flags, &width, &height,
+ &components, NULL));
+ if (expected_width > 0) { // we expect the file to decode into 'something'
+ CHECK_EQ(width, expected_width);
+ CHECK_EQ(height, expected_height);
+ CHECK_EQ(components, 3);
+ CHECK(imgdata.get());
+ if (!reference_RGB_file.empty()) {
+ string ref;
+ ReadFileToStringOrDie(env, reference_RGB_file, &ref);
+ CHECK(!memcmp(ref.data(), imgdata.get(), ref.size()));
+ }
+ } else { // not decodable
+ CHECK(!imgdata.get()) << "file:" << bad_jpeg_file;
+ }
+}
+
+TEST(JpegMemTest, BadJpeg) {
+ Env* env = Env::Default();
+ const string data_path = kTestData;
+
+ // Test corrupt file
+ TestBadJPEG(env, data_path + "bad_huffman.jpg", 1024, 768, "", false);
+ TestBadJPEG(env, data_path + "corrupt.jpg", 0 /*120*/, 90, "", false);
+
+ // Truncated files, undecodable because of missing lines:
+ TestBadJPEG(env, data_path + "corrupt34_2.jpg", 0, 3300, "", false);
+ TestBadJPEG(env, data_path + "corrupt34_3.jpg", 0, 3300, "", false);
+ TestBadJPEG(env, data_path + "corrupt34_4.jpg", 0, 3300, "", false);
+
+ // Try in 'recover' mode now:
+ TestBadJPEG(env, data_path + "corrupt34_2.jpg", 2544, 3300, "", true);
+ TestBadJPEG(env, data_path + "corrupt34_3.jpg", 2544, 3300, "", true);
+ TestBadJPEG(env, data_path + "corrupt34_4.jpg", 2544, 3300, "", true);
+}
+
+} // namespace
+} // namespace jpeg
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg
new file mode 100644
index 0000000000..ef5b6f12c5
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/bad_huffman.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg
new file mode 100644
index 0000000000..5e2fe6c56f
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg
new file mode 100644
index 0000000000..4211155c45
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_2.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg
new file mode 100644
index 0000000000..c1c2a9d1e1
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_3.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg
new file mode 100644
index 0000000000..b8e7308ba0
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/corrupt34_4.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg
new file mode 100644
index 0000000000..5e348a12fd
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1.jpg
Binary files differ
diff --git a/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg
new file mode 100644
index 0000000000..15f895960d
--- /dev/null
+++ b/tensorflow/core/lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg
Binary files differ
diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc
new file mode 100644
index 0000000000..43b84e41e0
--- /dev/null
+++ b/tensorflow/core/lib/png/png_io.cc
@@ -0,0 +1,385 @@
+// Functions to read and write images in PNG format.
+
+#include <string.h>
+#include <sys/types.h>
+#include <string>
+#include <utility>
+#include <vector>
+// NOTE(skal): we don't '#include <setjmp.h>' before png/png.h as it otherwise
+// provokes a compile error. We instead let png.h include what is needed.
+
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/png/png_io.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h" // endian
+#include "external/png_archive/libpng-1.2.53/png.h"
+
+namespace tensorflow {
+namespace png {
+
+////////////////////////////////////////////////////////////////////////////////
+// Encode an 8- or 16-bit rgb/grayscale image to PNG string
+////////////////////////////////////////////////////////////////////////////////
+
+namespace {
+
+#define PTR_INC(type, ptr, del) (ptr = \
+ reinterpret_cast<type*>(reinterpret_cast<char*>(ptr) + (del)))
+#define CPTR_INC(type, ptr, del) (ptr = \
+ reinterpret_cast<const type*>(reinterpret_cast<const char*>(ptr) + (del)))
+
+// Convert from 8-bit components to 16-bit. This works in-place: the copy runs
+// backwards from the last pixel so source bytes are read before being
+// overwritten by the wider output.
+static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
+ int width, int height, uint16* p16,
+ int p16_row_bytes) {
+ // Adjust pointers to copy backwards
+ width *= num_comps;
+ CPTR_INC(uint8, p8, (height - 1) * p8_row_bytes +
+ (width - 1) * sizeof(*p8));
+ PTR_INC(uint16, p16, (height - 1) * p16_row_bytes +
+ (width - 1) * sizeof(*p16));
+ int bump8 = width * sizeof(*p8) - p8_row_bytes;
+ int bump16 = width * sizeof(*p16) - p16_row_bytes;
+ for (; height-- != 0;
+ CPTR_INC(uint8, p8, bump8), PTR_INC(uint16, p16, bump16)) {
+ for (int w = width; w-- != 0; --p8, --p16) {
+ uint pix = *p8;
+ pix |= pix << 8;
+ *p16 = static_cast<uint16>(pix);
+ }
+ }
+}
+
+#undef PTR_INC
+#undef CPTR_INC
+
+void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
+ DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
+ ctx->error_condition = true;
+ // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
+ VLOG(1) << "PNG error: " << msg;
+ longjmp(png_jmpbuf(png_ptr), 1);
+}
+
+void WarningHandler(png_structp png_ptr, png_const_charp msg) {
+ LOG(WARNING) << "PNG warning: " << msg;
+}
+
+void StringReader(png_structp png_ptr,
+ png_bytep data, png_size_t length) {
+ DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
+ if (static_cast<png_size_t>(ctx->data_left) < length) {
+ if (!ctx->error_condition) {
+ VLOG(1) << "PNG read decoding error";
+ ctx->error_condition = true;
+ }
+ memset(data, 0, length);
+ } else {
+ memcpy(data, ctx->data, length);
+ ctx->data += length;
+ ctx->data_left -= length;
+ }
+}
+
+void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
+ string* const s = bit_cast<string*>(png_get_io_ptr(png_ptr));
+ s->append(bit_cast<const char*>(data), length);
+}
+
+void StringWriterFlush(png_structp png_ptr) {
+}
+
+char* check_metadata_string(const string& s) {
+ const char* const c_string = s.c_str();
+ const size_t length = s.size();
+ if (strlen(c_string) != length) {
+ LOG(WARNING) << "Warning! Metadata contains \\0 character(s).";
+ }
+ return const_cast<char*>(c_string);
+}
+
+} // namespace
+
+// We move CommonInitDecode() and CommonFinishDecode()
+// out of the CommonDecode() template to save code space.
+void CommonFreeDecode(DecodeContext* context) {
+ if (context->png_ptr) {
+ png_destroy_read_struct(&context->png_ptr,
+ context->info_ptr ? &context->info_ptr : NULL, 0);
+ context->png_ptr = nullptr;
+ context->info_ptr = nullptr;
+ }
+}
+
+bool DecodeHeader(StringPiece png_string, int* width, int* height,
+ int* components, int* channel_bit_depth,
+ std::vector<std::pair<string, string> >* metadata) {
+ DecodeContext context;
+ // Ask for 16 bits even if there may be fewer. This assures that sniffing
+ // the metadata will succeed in all cases.
+ //
+ // TODO(skal): CommonInitDecode() mixes the operation of sniffing the
+ // metadata with setting up the data conversions. These should be separated.
+ constexpr int kDesiredNumChannels = 1;
+ constexpr int kDesiredChannelBits = 16;
+ if (!CommonInitDecode(png_string, kDesiredNumChannels, kDesiredChannelBits,
+ &context)) {
+ return false;
+ }
+ CHECK_NOTNULL(width);
+ *width = static_cast<int>(context.width);
+ CHECK_NOTNULL(height);
+ *height = static_cast<int>(context.height);
+ if (components != NULL) {
+ switch (context.color_type) {
+ case PNG_COLOR_TYPE_PALETTE:
+ *components = (context.info_ptr->valid & PNG_INFO_tRNS) ? 4 : 3;
+ break;
+ case PNG_COLOR_TYPE_GRAY:
+ *components = 1;
+ break;
+ case PNG_COLOR_TYPE_GRAY_ALPHA:
+ *components = 2;
+ break;
+ case PNG_COLOR_TYPE_RGB:
+ *components = 3;
+ break;
+ case PNG_COLOR_TYPE_RGB_ALPHA:
+ *components = 4;
+ break;
+ default:
+ *components = 0;
+ break;
+ }
+ }
+ if (channel_bit_depth != NULL) {
+ *channel_bit_depth = context.bit_depth;
+ }
+ if (metadata != NULL) {
+ metadata->clear();
+ for (int i = 0; i < context.info_ptr->num_text; i++) {
+ const png_text& text = context.info_ptr->text[i];
+ metadata->push_back(std::make_pair(text.key, text.text));
+ }
+ }
+ CommonFreeDecode(&context);
+ return true;
+}
+
+bool CommonInitDecode(StringPiece png_string, int desired_channels,
+ int desired_channel_bits, DecodeContext* context) {
+ CHECK(desired_channel_bits == 8 || desired_channel_bits == 16)
+ << "desired_channel_bits = " << desired_channel_bits;
+ CHECK(0 <= desired_channels && desired_channels <= 4) << "desired_channels = "
+ << desired_channels;
+ context->error_condition = false;
+ context->channels = desired_channels;
+ context->png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, context,
+ ErrorHandler, WarningHandler);
+ if (!context->png_ptr) {
+ VLOG(1) << ": DecodePNG <- png_create_read_struct failed";
+ return false;
+ }
+ if (setjmp(png_jmpbuf(context->png_ptr))) {
+ VLOG(1) << ": DecodePNG error trapped.";
+ CommonFreeDecode(context);
+ return false;
+ }
+ context->info_ptr = png_create_info_struct(context->png_ptr);
+ if (!context->info_ptr || context->error_condition) {
+ VLOG(1) << ": DecodePNG <- png_create_info_struct failed";
+ CommonFreeDecode(context);
+ return false;
+ }
+ context->data = bit_cast<const uint8*>(png_string.data());
+ context->data_left = png_string.size();
+ png_set_read_fn(context->png_ptr, context, StringReader);
+ png_read_info(context->png_ptr, context->info_ptr);
+ png_get_IHDR(context->png_ptr, context->info_ptr,
+ &context->width, &context->height,
+ &context->bit_depth, &context->color_type,
+ 0, 0, 0);
+ if (context->error_condition) {
+ VLOG(1) << ": DecodePNG <- error during header parsing.";
+ CommonFreeDecode(context);
+ return false;
+ }
+ if (context->width <= 0 || context->height <= 0) {
+ VLOG(1) << ": DecodePNG <- invalid dimensions";
+ CommonFreeDecode(context);
+ return false;
+ }
+ if (context->channels == 0) { // Autodetect number of channels
+ context->channels = context->info_ptr->channels;
+ }
+ const bool has_tRNS = (context->info_ptr->valid & PNG_INFO_tRNS) != 0;
+ const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
+ if ((context->channels & 1) == 0) { // We desire alpha
+ if (has_alpha) { // There is alpha
+ } else if (has_tRNS) {
+ png_set_tRNS_to_alpha(context->png_ptr); // Convert transparency to alpha
+ } else {
+ png_set_add_alpha(
+ context->png_ptr, (1 << context->bit_depth) - 1, PNG_FILLER_AFTER);
+ }
+ } else { // We don't want alpha
+ if (has_alpha || has_tRNS) { // There is alpha
+ png_set_strip_alpha(context->png_ptr); // Strip alpha
+ }
+ }
+
+ // If we only want 8 bits, but are given 16, strip off the LS 8 bits
+ if (context->bit_depth > 8 && desired_channel_bits <= 8)
+ png_set_strip_16(context->png_ptr);
+
+ context->need_to_synthesize_16 =
+ (context->bit_depth <= 8 && desired_channel_bits == 16);
+
+ png_set_packing(context->png_ptr);
+ context->num_passes = png_set_interlace_handling(context->png_ptr);
+ png_read_update_info(context->png_ptr, context->info_ptr);
+
+#ifdef IS_LITTLE_ENDIAN
+ if (desired_channel_bits > 8)
+ png_set_swap(context->png_ptr);
+#endif // IS_LITTLE_ENDIAN
+
+ // convert palette to rgb(a) if need be.
+ if (context->color_type == PNG_COLOR_TYPE_PALETTE)
+ png_set_palette_to_rgb(context->png_ptr);
+
+ // handle grayscale case for source or destination
+ const bool want_gray = (context->channels < 3);
+ const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
+ if (is_gray) { // upconvert gray to 8-bit if needed.
+ if (context->bit_depth < 8)
+ png_set_gray_1_2_4_to_8(context->png_ptr);
+ }
+ if (want_gray) { // output is grayscale
+ if (!is_gray)
+ png_set_rgb_to_gray(context->png_ptr, 1, 0.299, 0.587); // 601, JPG
+ } else { // output is rgb(a)
+ if (is_gray)
+ png_set_gray_to_rgb(context->png_ptr); // Enable gray -> RGB conversion
+ }
+ return true;
+}
+
+bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
+ CHECK_NOTNULL(data);
+
+ // we need to re-set the jump point so that we trap the errors
+ // within *this* function (and not CommonInitDecode())
+ if (setjmp(png_jmpbuf(context->png_ptr))) {
+ VLOG(1) << ": DecodePNG error trapped.";
+ CommonFreeDecode(context);
+ return false;
+ }
+ // png_read_row() takes care of offsetting the pointer based on interlacing
+ for (int p = 0; p < context->num_passes; ++p) {
+ png_bytep row = data;
+ for (int h = context->height; h-- != 0; row += row_bytes) {
+ png_read_row(context->png_ptr, row, NULL);
+ }
+ }
+
+ context->info_ptr->valid |= PNG_INFO_IDAT;
+ png_read_end(context->png_ptr, context->info_ptr);
+
+ // Clean up.
+ const bool ok = !context->error_condition;
+ CommonFreeDecode(context);
+
+ // Synthesize 16 bits from 8 if requested.
+ if (context->need_to_synthesize_16)
+ Convert8to16(bit_cast<uint8*>(data), context->channels, row_bytes,
+ context->width, context->height, bit_cast<uint16*>(data),
+ row_bytes);
+ return ok;
+}
+
+bool WriteImageToBuffer(
+ const void* image, int width, int height, int row_bytes, int num_channels,
+ int channel_bits, int compression, string* png_string,
+ const std::vector<std::pair<string, string> >* metadata) {
+ CHECK_NOTNULL(image);
+ CHECK_NOTNULL(png_string);
+ // Although this case is checked inside png.cc and issues an error message,
+ // that error causes memory corruption.
+ if (width == 0 || height == 0)
+ return false;
+
+ png_string->resize(0);
+ png_infop info_ptr = NULL;
+ png_structp png_ptr =
+ png_create_write_struct(PNG_LIBPNG_VER_STRING,
+ NULL, ErrorHandler, WarningHandler);
+ if (png_ptr == NULL) return false;
+ if (setjmp(png_jmpbuf(png_ptr))) {
+ png_destroy_write_struct(&png_ptr, info_ptr ? &info_ptr : NULL);
+ return false;
+ }
+ info_ptr = png_create_info_struct(png_ptr);
+ if (info_ptr == NULL) {
+ png_destroy_write_struct(&png_ptr, NULL);
+ return false;
+ }
+
+ int color_type = -1;
+ switch (num_channels) {
+ case 1:
+ color_type = PNG_COLOR_TYPE_GRAY;
+ break;
+ case 2:
+ color_type = PNG_COLOR_TYPE_GRAY_ALPHA;
+ break;
+ case 3:
+ color_type = PNG_COLOR_TYPE_RGB;
+ break;
+ case 4:
+ color_type = PNG_COLOR_TYPE_RGB_ALPHA;
+ break;
+ default:
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+ return false;
+ }
+
+ png_set_write_fn(png_ptr, png_string, StringWriter, StringWriterFlush);
+ if (compression < 0) compression = Z_DEFAULT_COMPRESSION;
+ png_set_compression_level(png_ptr, compression);
+ png_set_compression_mem_level(png_ptr, MAX_MEM_LEVEL);
+ // There used to be a call to png_set_filter here turning off filtering
+ // entirely, but it produced pessimal compression ratios. I'm not sure
+ // why it was there.
+ png_set_IHDR(png_ptr, info_ptr, width, height, channel_bits, color_type,
+ PNG_INTERLACE_NONE, PNG_COMPRESSION_TYPE_DEFAULT,
+ PNG_FILTER_TYPE_DEFAULT);
+ // If we have metadata, write it to the header.
+ if (metadata && !metadata->empty()) {
+ std::vector<png_text> text;
+ for (const auto& pair : *metadata) {
+ png_text txt;
+ txt.compression = PNG_TEXT_COMPRESSION_NONE;
+ txt.key = check_metadata_string(pair.first);
+ txt.text = check_metadata_string(pair.second);
+ text.push_back(txt);
+ }
+ png_set_text(png_ptr, info_ptr, &text[0], text.size());
+ }
+
+ png_write_info(png_ptr, info_ptr);
+#ifdef IS_LITTLE_ENDIAN
+ if (channel_bits > 8)
+ png_set_swap(png_ptr);
+#endif // IS_LITTLE_ENDIAN
+
+ png_byte* row = reinterpret_cast<png_byte*>(const_cast<void*>(image));
+ for (; height--; row += row_bytes) png_write_row(png_ptr, row);
+ png_write_end(png_ptr, NULL);
+
+ png_destroy_write_struct(&png_ptr, &info_ptr);
+ return true;
+}
+
+} // namespace png
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
new file mode 100644
index 0000000000..df9bff7be8
--- /dev/null
+++ b/tensorflow/core/lib/png/png_io.h
@@ -0,0 +1,88 @@
+// Functions to read and write images in PNG format.
+//
+// The advantage over image/codec/png{enc,dec}oder.h is that this library
+// supports both 8 and 16 bit images.
+//
+// The decoding routine accepts binary image data as a StringPiece. These are
+// implicitly constructed from strings or char* so they're completely
+// transparent to the caller. They're also very cheap to construct so this
+// doesn't introduce any additional overhead.
+//
+// The primary benefit of StringPiece, in this case, is that APIs already
+// returning StringPieces (e.g., Bigtable Scanner) or Cords (e.g., IOBuffer;
+// only when they're flat, though) or protocol buffer fields typed to either of
+// these can be decoded without copying the data into a C++ string.
+
+#ifndef TENSORFLOW_LIB_PNG_PNG_IO_H_
+#define TENSORFLOW_LIB_PNG_PNG_IO_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "external/png_archive/libpng-1.2.53/png.h"
+
+namespace tensorflow {
+namespace png {
+
+// Handy container for decoding information and struct pointers
+struct DecodeContext {
+ const uint8* data;
+ int data_left;
+ png_structp png_ptr;
+ png_infop info_ptr;
+ png_uint_32 width, height;
+ int num_passes;
+ int color_type;
+ int bit_depth;
+ int channels;
+ bool need_to_synthesize_16;
+ bool error_condition;
+ DecodeContext() : png_ptr(NULL), info_ptr(NULL) {}
+};
+
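+// Reads just the PNG header from png_string and reports the image width,
+// height, number of components and per-channel bit depth, plus any text
+// metadata, without decoding the pixel data. All output pointers other than
+// width and height may be NULL. Returns false if the header cannot be parsed.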
+bool DecodeHeader(StringPiece png_string, int* width, int* height,
+ int* components, int* channel_bit_depth,
+ std::vector<std::pair<string, string> >* metadata);
+
+// Sample usage for reading PNG:
+//
+// string png_string; /* fill with input PNG format data */
+// DecodeContext context;
+// CHECK(CommonInitDecode(png_string, 3 /*RGB*/, 8 /*uint8*/, &context));
+// char* image_buffer = new char[3*context.width*context.height];
+// CHECK(CommonFinishDecode(bit_cast<png_byte*>(image_buffer),
+// 3*context.width /*stride*/, &context));
+//
+// desired_channels may be 0 to autodetect the channel count from the input.
+
+bool CommonInitDecode(StringPiece png_string, int desired_channels,
+ int desired_channel_bits, DecodeContext* context);
+
+bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context);
+
+// Normally called automatically from CommonFinishDecode. If CommonInitDecode
+// is called but not CommonFinishDecode, call this to clean up. Safe to call
+// extra times.
+void CommonFreeDecode(DecodeContext* context);
+
+// Sample usage for writing PNG:
+//
+// uint16* image_buffer = new uint16[width*height]; /* fill with pixels */
+// string png_string;
+// CHECK(WriteImageToBuffer(image_buffer, width, height, 2*width /*stride*/,
+// 1 /*gray*/, 16 /*uint16*/, &png_string, NULL));
+//
+// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow
+// and strong, and -1 is the zlib default.
+
+bool WriteImageToBuffer(
+ const void* image, int width, int height, int row_bytes, int num_channels,
+ int channel_bits, int compression, string* png_string,
+ const std::vector<std::pair<string, string> >* metadata);
+
+} // namespace png
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_PNG_PNG_IO_H_
diff --git a/tensorflow/core/lib/png/testdata/lena_gray.png b/tensorflow/core/lib/png/testdata/lena_gray.png
new file mode 100644
index 0000000000..8bc73159b0
--- /dev/null
+++ b/tensorflow/core/lib/png/testdata/lena_gray.png
Binary files differ
diff --git a/tensorflow/core/lib/png/testdata/lena_rgba.png b/tensorflow/core/lib/png/testdata/lena_rgba.png
new file mode 100644
index 0000000000..79f1f84a62
--- /dev/null
+++ b/tensorflow/core/lib/png/testdata/lena_rgba.png
Binary files differ
diff --git a/tensorflow/core/lib/random/distribution_sampler.cc b/tensorflow/core/lib/random/distribution_sampler.cc
new file mode 100644
index 0000000000..341f1bd595
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.cc
@@ -0,0 +1,80 @@
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+
+#include <memory>
+#include <vector>
+
+namespace tensorflow {
+namespace random {
+
+DistributionSampler::DistributionSampler(
+ const gtl::ArraySlice<float>& weights) {
+ DCHECK(!weights.empty());
+ int n = weights.size();
+ num_ = n;
+ data_.reset(new std::pair<float, int>[n]);
+
+ std::unique_ptr<double[]> pr(new double[n]);
+
+ double sum = 0.0;
+ for (int i = 0; i < n; i++) {
+ sum += weights[i];
+ set_alt(i, -1);
+ }
+
+ // These are long/short items - called high/low because of reserved keywords.
+ std::vector<int> high;
+ high.reserve(n);
+ std::vector<int> low;
+ low.reserve(n);
+
+ // compute proportional weights
+ for (int i = 0; i < n; i++) {
+ double p = (weights[i] * n) / sum;
+ pr[i] = p;
+ if (p < 1.0) {
+ low.push_back(i);
+ } else {
+ high.push_back(i);
+ }
+ }
+
+ // Now pair high with low.
+ while (!high.empty() && !low.empty()) {
+ int l = low.back();
+ low.pop_back();
+ int h = high.back();
+ high.pop_back();
+
+ set_alt(l, h);
+ DCHECK_GE(pr[h], 1.0);
+ double remaining = pr[h] - (1.0 - pr[l]);
+ pr[h] = remaining;
+
+ if (remaining < 1.0) {
+ low.push_back(h);
+ } else {
+ high.push_back(h);
+ }
+ }
+ // Transfer pr to prob with rounding errors.
+ for (int i = 0; i < n; i++) {
+ set_prob(i, pr[i]);
+ }
+ // Because of rounding errors, both high and low may still contain elements
+ // whose probability is close to 1.0.
+ for (size_t i = 0; i < high.size(); i++) {
+ int idx = high[i];
+ set_prob(idx, 1.0);
+ // set alt to self to prevent rounding errors returning 0
+ set_alt(idx, idx);
+ }
+ for (size_t i = 0; i < low.size(); i++) {
+ int idx = low[i];
+ set_prob(idx, 1.0);
+ // set alt to self to prevent rounding errors returning 0
+ set_alt(idx, idx);
+ }
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h
new file mode 100644
index 0000000000..ab9598a205
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler.h
@@ -0,0 +1,79 @@
+// DistributionSampler allows generating a discrete random variable with a given
+// distribution.
+// The values taken by the variable are [0, N) and relative weights for each
+// value are specified using a vector of size N.
+//
+// The algorithm takes O(N) time to precompute data at construction time and
+// takes O(1) time (2 random number generations, 2 lookups) for each sample.
+// The data structure takes O(N) memory.
+//
+// In contrast, util/random/weighted-picker.h provides O(lg N) sampling.
+// The advantage of that implementation is that weights can be adjusted
+// dynamically, while DistributionSampler doesn't allow weight adjustment.
+//
+// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
+
+#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
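+// Sample usage (a sketch; the weights and seed values are illustrative only):
+//
+// std::vector<float> weights = {2.0f, 1.0f, 1.0f}; // index 0 twice as likely
+// DistributionSampler sampler(weights);
+// PhiloxRandom philox(301, 17);
+// SimplePhilox rand(&philox);
+// int value = sampler.Sample(&rand); // in [0, weights.size())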
+class DistributionSampler {
+ public:
+ explicit DistributionSampler(const gtl::ArraySlice<float>& weights);
+
+ ~DistributionSampler() {}
+
+ int Sample(SimplePhilox* rand) const {
+ float r = rand->RandFloat();
+ // Since n is typically low, we don't bother with UnbiasedUniform.
+ int idx = rand->Uniform(num_);
+ if (r < prob(idx)) return idx;
+ // else pick alt from that bucket.
+ DCHECK_NE(-1, alt(idx));
+ return alt(idx);
+ }
+
+ int num() const { return num_; }
+
+ private:
+ float prob(int idx) const {
+ DCHECK_LT(idx, num_);
+ return data_[idx].first;
+ }
+
+ int alt(int idx) const {
+ DCHECK_LT(idx, num_);
+ return data_[idx].second;
+ }
+
+ void set_prob(int idx, float f) {
+ DCHECK_LT(idx, num_);
+ data_[idx].first = f;
+ }
+
+ void set_alt(int idx, int val) {
+ DCHECK_LT(idx, num_);
+ data_[idx].second = val;
+ }
+
+ int num_;
+ std::unique_ptr<std::pair<float, int>[]> data_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DistributionSampler);
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
diff --git a/tensorflow/core/lib/random/distribution_sampler_test.cc b/tensorflow/core/lib/random/distribution_sampler_test.cc
new file mode 100644
index 0000000000..d61a8daa0f
--- /dev/null
+++ b/tensorflow/core/lib/random/distribution_sampler_test.cc
@@ -0,0 +1,90 @@
+#include "tensorflow/core/lib/random/distribution_sampler.h"
+
+#include <string.h>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+class DistributionSamplerTest : public ::testing::Test {
+ protected:
+ // Returns the Chi-Squared statistic for the two distributions.
+ float TestWeights(const std::vector<float>& weights, int trials_per_bin) {
+ int iters = weights.size() * trials_per_bin;
+ std::unique_ptr<float[]> counts(new float[weights.size()]);
+ memset(counts.get(), 0, sizeof(float) * weights.size());
+ DistributionSampler sampler(weights);
+ PhiloxRandom philox(testing::RandomSeed(), 17);
+ SimplePhilox random(&philox);
+ for (int i = 0; i < iters; i++) {
+ int r = sampler.Sample(&random);
+ EXPECT_LT(r, weights.size());
+ EXPECT_GE(r, 0);
+ counts[r] += 1.0;
+ }
+ float chi2 = 0.0;
+ for (size_t i = 0; i < weights.size(); i++) {
+ counts[i] /= iters;
+ float err = (counts[i] - weights[i]);
+ chi2 += (err * err) / weights[i];
+ }
+ return chi2;
+ }
+
+ void TestDistribution(float* arr, int n) {
+ std::vector<float> w;
+ w.reserve(n);
+ for (int i = 0; i < n; i++) {
+ w.push_back(arr[i]);
+ }
+ float var = TestWeights(w, 1000);
+ if (var < 0.001) return;
+ // Maybe a statistical skew. Let's try more iterations.
+ var = TestWeights(w, 100000);
+ if (var < 0.001) return;
+ EXPECT_TRUE(false) << "Chi2 is " << var << " in " << n * 100000
+ << "iterations";
+ }
+};
+
+TEST_F(DistributionSamplerTest, KnownDistribution) {
+ float kEven2[] = {0.5, 0.5};
+ float kEven3[] = {0.33333333, 0.33333333, 0.33333333};
+ float kEven4[] = {0.25, 0.25, 0.25, 0.25};
+
+ float kDist1[] = {0.8, 0.15, 0.05};
+
+ TestDistribution(kEven2, TF_ARRAYSIZE(kEven2));
+ TestDistribution(kEven3, TF_ARRAYSIZE(kEven3));
+ TestDistribution(kEven4, TF_ARRAYSIZE(kEven4));
+ TestDistribution(kDist1, TF_ARRAYSIZE(kDist1));
+}
+
+static void BM_DistributionSampler(int iters, int n) {
+ testing::StopTiming();
+ PhiloxRandom philox(173, 371);
+ SimplePhilox rand(&philox);
+ std::vector<float> weights(n, 0);
+ for (int i = 0; i < n; i++) {
+ weights[i] = rand.Uniform(100);
+ }
+ DistributionSampler picker(weights);
+ testing::StartTiming();
+ int r = 0;
+ for (int i = 0; i < iters; i++) {
+ r |= picker.Sample(&rand);
+ }
+ CHECK_NE(r, kint32max);
+}
+
+BENCHMARK(BM_DistributionSampler)->Arg(10)->Arg(100)->Arg(1000);
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h
new file mode 100644
index 0000000000..616354cc5c
--- /dev/null
+++ b/tensorflow/core/lib/random/exact_uniform_int.h
@@ -0,0 +1,68 @@
+// Exact uniform integers using rejection sampling
+
+#ifndef TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+#define TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
+
+#include <type_traits>
+
+namespace tensorflow {
+namespace random {
+
+template <typename UintType, typename RandomBits>
+UintType ExactUniformInt(const UintType n, const RandomBits& random) {
+ static_assert(std::is_unsigned<UintType>::value,
+ "UintType must be an unsigned int");
+ static_assert(std::is_same<UintType, decltype(random())>::value,
+ "random() should return UintType");
+ if (n == 0) {
+ // Consume a value anyway
+ // TODO(irving): Assert n != 0, since this case makes no sense.
+ return random() * n;
+ } else if (0 == (n & (n - 1))) {
+ // N is a power of two, so just mask off the lower bits.
+ return random() & (n - 1);
+ } else {
+ // Reject all numbers that skew the distribution towards 0.
+
+ // random's output is uniform in the half-open interval [0, 2^{bits}).
+ // For any interval [m,n), the number of elements in it is n-m.
+
+ const UintType range = ~static_cast<UintType>(0);
+ const UintType rem = (range % n) + 1;
+ UintType rnd;
+
+ // rem = ((2^bits-1) \bmod n) + 1
+ // 1 <= rem <= n
+
+ // NB: rem == n is impossible, since n is not a power of 2 (from
+ // earlier check).
+
+ do {
+ rnd = random(); // rnd uniform over [0, 2^{bits})
+ } while (rnd < rem); // reject [0, rem)
+ // rnd is uniform over [rem, 2^{bits})
+ //
+ // The number of elements in the half-open interval is
+ //
+ // 2^{bits} - rem = 2^{bits} - ((2^{bits}-1) \bmod n) - 1
+ // = 2^{bits}-1 - ((2^{bits}-1) \bmod n)
+ // = n \cdot \lfloor (2^{bits}-1)/n \rfloor
+ //
+ // therefore n evenly divides the number of integers in the
+ // interval.
+ //
+ // The function v \rightarrow v % n takes values from [rem,
+ // 2^{bits}) to [0, n). Each integer in the range interval [0, n)
+ // will have exactly \lfloor (2^{bits}-1)/n \rfloor preimages from
+ // the domain interval.
+ //
+ // Therefore, v % n is uniform over [0, n). QED.
+
+ return rnd % n;
+ }
+}
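+
+// Sample usage (a sketch; std::mt19937 is used only as an illustrative 32-bit
+// source -- any callable returning the matching unsigned type works):
+//
+// std::mt19937 gen(seed);
+// uint32 die = ExactUniformInt<uint32>(
+// 6u, [&gen]() -> uint32 { return static_cast<uint32>(gen()); });
+// // die is uniform over {0, ..., 5} with no modulo bias.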
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_EXACT_UNIFORM_H_
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
new file mode 100644
index 0000000000..2c3cd0c4b9
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -0,0 +1,232 @@
+// Implement the Philox algorithm to generate random numbers in parallel.
+// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
+// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
+
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+
+#include <stdlib.h>
+
+#include "tensorflow/core/platform/port.h"
+
+// Function qualifiers that need to work on both CPU and GPU.
+#ifdef __CUDA_ARCH__
+// For nvcc.
+#define PHILOX_DEVICE_FUNC __host__ __device__
+#define PHILOX_INLINE __inline__
+#else
+// For non-nvcc.
+#define PHILOX_DEVICE_FUNC
+#define PHILOX_INLINE inline
+#endif
+#define PHILOX_DEVICE_INLINE PHILOX_DEVICE_FUNC PHILOX_INLINE
+
+#include <math.h>
+
+namespace tensorflow {
+namespace random {
+
+// A class that represents an inline array. It can be used on both CPU and GPU,
+// and is trivially copyable between CPU and GPU.
+// Arguments:
+// T: the array element type;
+// ElementCount: the fixed size of the array;
+template <typename T, int ElementCount>
+class Array {
+ public:
+ PHILOX_DEVICE_INLINE Array() {
+ for (int i = 0; i < ElementCount; ++i) {
+ data_[i] = T();
+ }
+ }
+
+ PHILOX_DEVICE_INLINE const T& operator[](int index) const {
+ return data_[index];
+ }
+
+ PHILOX_DEVICE_INLINE T& operator[](int index) { return data_[index]; }
+
+ size_t size() const { return ElementCount; }
+
+ private:
+ T data_[ElementCount];
+};
+
+// A class that encapsulates all the state for a random number generator using
+// the philox_4x32_10 algorithm. Each invocation returns 128 random bits
+// in the form of four uint32.
+// There are multiple variants of this algorithm; we picked the 4x32_10 version
+// that is most suited to our applications.
+// Since this class is meant to be copied between CPU and GPU, it maintains
+// value semantics.
+//
+// For example: To use this class and populate an array of 1024 randoms on CPU
+// with two threads,
+//
+// void Fill(PhiloxRandom rnd, uint32* output, int start, int limit) {
+// assert(start % 4 == 0);
+// assert(limit % 4 == 0);
+// rnd.Skip(start / 4);
+// for (int i = start; i < limit; i += 4) {
+// auto sample = rnd();
+// ... copy sample[0..3] to output[i..i+3]
+// }
+// }
+//
+// PhiloxRandom rng(seed);
+// PhiloxRandom rng_copy = rng;
+// rng.Skip(1000/4);
+//
+// ... schedule Fill(rng_copy, output, 0, 512) in thread 1;
+// ... schedule Fill(rng_copy, output, 512, 1024) in thread 2;
+// ... wait for thread 1 & 2 to finish executing Fill().
+//
+// NOTE:
+// 1. PhiloxRandom is trivially copyable.
+// 2. PhiloxRandom is compilable by gcc and nvcc.
+class PhiloxRandom {
+ public:
+ typedef Array<uint32, 4> ResultType;
+ typedef uint32 ResultElementType;
+ // The number of elements that will be returned.
+ static const int kResultElementCount = 4;
+
+ PHILOX_DEVICE_INLINE
+ PhiloxRandom() {}
+
+ PHILOX_DEVICE_INLINE
+ explicit PhiloxRandom(uint64 seed) {
+ key_[0] = static_cast<uint32>(seed);
+ key_[1] = static_cast<uint32>(seed >> 32);
+ }
+
+ PHILOX_DEVICE_INLINE
+ explicit PhiloxRandom(uint64 seed_lo, uint64 seed_hi) {
+ key_[0] = static_cast<uint32>(seed_lo);
+ key_[1] = static_cast<uint32>(seed_lo >> 32);
+ counter_[2] = static_cast<uint32>(seed_hi);
+ counter_[3] = static_cast<uint32>(seed_hi >> 32);
+ }
+
+ // Skip the specified number of 128-bit samples in the current stream.
+ PHILOX_DEVICE_INLINE
+ void Skip(uint64 count) {
+ const uint32 count_lo = static_cast<uint32>(count);
+ uint32 count_hi = static_cast<uint32>(count >> 32);
+
+ counter_[0] += count_lo;
+ if (counter_[0] < count_lo) {
+ ++count_hi;
+ }
+
+ counter_[1] += count_hi;
+ if (counter_[1] < count_hi) {
+ if (++counter_[2] == 0) {
+ ++counter_[3];
+ }
+ }
+ }
+
+ // Returns a group of four random numbers using the underlying Philox
+ // algorithm.
+ PHILOX_DEVICE_INLINE ResultType operator()() {
+ ResultType counter = counter_;
+ Key key = key_;
+
+ // Run the single round ten times. The loop is manually unrolled
+ // for better performance.
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+
+ SkipOne();
+
+ return counter;
+ }
+
+ private:
+ // The type for the 64-bit key stored in the form of two 32-bit uint
+ // that are used in the diffusion process.
+ typedef Array<uint32, 2> Key;
+
+ // We use the same constants as recommended by the original paper.
+ static const uint32 kPhiloxW32A = 0x9E3779B9;
+ static const uint32 kPhiloxW32B = 0xBB67AE85;
+ static const uint32 kPhiloxM4x32A = 0xD2511F53;
+ static const uint32 kPhiloxM4x32B = 0xCD9E8D57;
+
+ // Helper function to skip the next 128-bit sample in the current stream.
+ PHILOX_DEVICE_INLINE void SkipOne() {
+ if (++counter_[0] == 0) {
+ if (++counter_[1] == 0) {
+ if (++counter_[2] == 0) {
+ ++counter_[3];
+ }
+ }
+ }
+ }
+
+ // Helper function to return the lower and higher 32 bits of the product of
+ // two 32-bit integers.
+ PHILOX_DEVICE_INLINE
+ static void MultiplyHighLow(uint32 a, uint32 b, uint32* result_low,
+ uint32* result_high) {
+#ifndef __GCUDACC__
+ const uint64 product = static_cast<uint64>(a) * b;
+ *result_low = static_cast<uint32>(product);
+ *result_high = static_cast<uint32>(product >> 32);
+#else
+ *result_low = a * b;
+ *result_high = __umulhi(a, b);
+#endif
+ }
+
+ // Helper function for a single round of the underlying Philox algorithm.
+ PHILOX_DEVICE_INLINE static ResultType ComputeSingleRound(
+ const ResultType& counter, const Key& key) {
+ uint32 lo0;
+ uint32 hi0;
+ MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
+
+ uint32 lo1;
+ uint32 hi1;
+ MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
+
+ ResultType result;
+ result[0] = hi1 ^ counter[1] ^ key[0];
+ result[1] = lo1;
+ result[2] = hi0 ^ counter[3] ^ key[1];
+ result[3] = lo0;
+ return result;
+ }
+
+ PHILOX_DEVICE_INLINE void RaiseKey(Key* key) {
+ (*key)[0] += kPhiloxW32A;
+ (*key)[1] += kPhiloxW32B;
+ }
+
+ private:
+ ResultType counter_;
+ Key key_;
+};
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/lib/random/philox_random_test.cc b/tensorflow/core/lib/random/philox_random_test.cc
new file mode 100644
index 0000000000..997c0263b7
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test.cc
@@ -0,0 +1,58 @@
+#include "tensorflow/core/lib/random/philox_random.h"
+
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/random/philox_random_test_utils.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// A trivial distribution that just returns the PhiloxRandom as a distribution
+class TrivialPhiloxDistribution {
+ public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = PhiloxRandom::kResultElementCount;
+ typedef PhiloxRandom::ResultType ResultType;
+ typedef PhiloxRandom::ResultElementType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(PhiloxRandom* gen) { return (*gen)(); }
+};
+
+// This test checks that skipping a certain number of samples is equivalent to
+// generating (and discarding) the same number of samples without skipping.
+TEST(PhiloxRandomTest, SkipMatchTest) {
+ constexpr int count = 1024;
+ constexpr int skip_count = 2048;
+
+ uint64 test_seed = GetTestSeed();
+ std::vector<uint32> v1(count);
+ {
+ PhiloxRandom gen(test_seed);
+ gen.Skip(skip_count / 4);
+ FillRandoms<TrivialPhiloxDistribution>(gen, &v1[0], v1.size());
+ }
+
+ std::vector<uint32> v2(count + skip_count);
+ {
+ PhiloxRandom gen(test_seed);
+ FillRandoms<TrivialPhiloxDistribution>(gen, &v2[0], v2.size());
+ }
+
+ for (int i = 0; i < count; ++i) {
+ ASSERT_EQ(v1[i], v2[i + skip_count]);
+ }
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/philox_random_test_utils.h b/tensorflow/core/lib/random/philox_random_test_utils.h
new file mode 100644
index 0000000000..d22f6b36e4
--- /dev/null
+++ b/tensorflow/core/lib/random/philox_random_test_utils.h
@@ -0,0 +1,36 @@
+#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
+
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+// Return a random seed.
+inline uint64 GetTestSeed() { return New64(); }
+
+// A utility function to fill the given array with samples from the given
+// distribution.
+template <class Distribution>
+void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p,
+ int64 size) {
+ const int granularity = Distribution::kResultElementCount;
+
+ CHECK(size % granularity == 0) << " size: " << size
+ << " granularity: " << granularity;
+
+ Distribution dist;
+ for (int i = 0; i < size; i += granularity) {
+ const auto sample = dist(&gen);
+ std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+ }
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_
diff --git a/tensorflow/core/lib/random/random.cc b/tensorflow/core/lib/random/random.cc
new file mode 100644
index 0000000000..2959b05382
--- /dev/null
+++ b/tensorflow/core/lib/random/random.cc
@@ -0,0 +1,22 @@
+#include "tensorflow/core/lib/random/random.h"
+
+#include <random>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+std::mt19937_64* InitRng() {
+ std::random_device device("/dev/random");
+ return new std::mt19937_64(device());
+}
+
+uint64 New64() {
+ static std::mt19937_64* rng = InitRng();
+ static mutex mu;
+ mutex_lock l(mu);
+ return (*rng)();
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/random.h b/tensorflow/core/lib/random/random.h
new file mode 100644
index 0000000000..1a20436c4e
--- /dev/null
+++ b/tensorflow/core/lib/random/random.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_H_
+#define TENSORFLOW_LIB_RANDOM_RANDOM_H_
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+// Return a 64-bit random value. Different sequences are generated
+// in different processes.
+uint64 New64();
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_RANDOM_H_
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
new file mode 100644
index 0000000000..caafcde513
--- /dev/null
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -0,0 +1,361 @@
+#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+
+#include <math.h>
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+
+namespace tensorflow {
+namespace random {
+
+// Helper function to convert a 32-bit integer to a float in [0, 1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x);
+// Helper function to convert two 32-bit integers to a double in [0, 1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1);
+
+// A class that generates uniformly distributed random numbers from the
+// underlying random integer generator.
+// Arguments:
+// Generator: a generator type that returns a number of uint32s upon each
+// invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for the actual
+// returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class UniformDistribution;
+
+template <class Generator>
+class UniformDistribution<Generator, float> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount;
+ // Indicates that this distribution may take a variable number of samples
+ // during the runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i) {
+ result[i] = Uint32ToFloat(sample[i]);
+ }
+ return result;
+ }
+};
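+
+// Sample usage (a sketch; the seed value is illustrative only):
+//
+// PhiloxRandom gen(301);
+// UniformDistribution<PhiloxRandom, float> uniform;
+// auto samples = uniform(&gen); // four floats, each in [0, 1)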
+
+template <class Generator>
+class UniformDistribution<Generator, double> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount / 2;
+ // Indicates that this distribution may take a variable number of samples
+ // during the runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i) {
+ result[i] = Uint64ToDouble(sample[2 * i], sample[2 * i + 1]);
+ }
+ return result;
+ }
+};
+
+// A class that adapts a generator which natively produces multiple samples per
+// invocation so that it returns a single sample at a time.
+template <class Generator>
+class SingleSampleAdapter {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = 1;
+ // The number of elements that will be returned by the underlying generator.
+ static const int kNativeElementCount = Generator::kResultElementCount;
+ typedef typename Generator::ResultElementType ResultType;
+ typedef typename Generator::ResultElementType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ explicit SingleSampleAdapter(Generator* gen)
+ : generator_(gen), used_result_index_(Generator::kResultElementCount) {}
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()() {
+ if (used_result_index_ == Generator::kResultElementCount) {
+ unused_results_ = (*generator_)();
+ used_result_index_ = 0;
+ }
+
+ return unused_results_[used_result_index_++];
+ }
+
+ private:
+ Generator* generator_;
+ typename Generator::ResultType unused_results_;
+ int used_result_index_;
+};
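+
+// Illustrative usage sketch (added for exposition, not part of the original
+// change; the function name is hypothetical): the adapter buffers one native
+// batch from the wrapped generator and hands back one uint32 per call,
+// refilling automatically every kNativeElementCount calls.
+inline uint32 ExampleDrawSingleSamples(PhiloxRandom* gen, int n) {
+  SingleSampleAdapter<PhiloxRandom> single(gen);
+  uint32 last = 0;
+  for (int i = 0; i < n; ++i) {
+    last = single();  // one sample at a time
+  }
+  return last;
+}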
+
+// A class that generates unit normal distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+// Generator: a generator type that returns a number of uint32s upon each
+// invocation. It needs to define kResultElementCount for the sample count
+// of each invocation, and ResultType for the actual returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType>
+class NormalDistribution;
+
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1);
+
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+ double* d1);
+
+template <class Generator>
+class NormalDistribution<Generator, float> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2) {
+ BoxMullerFloat(sample[i], sample[i + 1], &result[i], &result[i + 1]);
+ }
+ return result;
+ }
+};
+
+template <class Generator>
+class NormalDistribution<Generator, double> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount = Generator::kResultElementCount / 2;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static const bool kVariableSamplesPerOutput = false;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator* gen) {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2) {
+ const int i2 = 2 * i;
+ BoxMullerDouble(sample[i2], sample[i2 + 1], sample[i2 + 2],
+ sample[i2 + 3], &result[i], &result[i + 1]);
+ }
+ return result;
+ }
+};
+
+// A class that generates samples from a standard normal distribution
+// truncated to [-kTruncateValue, kTruncateValue].
+// Arguments:
+// SingleSampleGenerator: a single-sample generator type that returns one
+// uint32 upon each invocation and defines kNativeElementCount for the
+// native sample count of the underlying generator.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class SingleSampleGenerator, typename RealType>
+class TruncatedNormalDistribution;
+
+// Partial specialization for float.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, float> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount =
+ SingleSampleGenerator::kNativeElementCount;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static const bool kVariableSamplesPerOutput = true;
+ // The threshold where the normal distribution is truncated.
+ const float kTruncateValue = 2.0f;
+
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator* gen) {
+ ResultType results;
+ int index = 0;
+ while (true) {
+ // Repeatedly take samples from the normal distribution, until we have
+ // the desired number of elements that fall within the pre-defined cutoff
+ // threshold.
+ const uint32 x0 = (*gen)();
+ const uint32 x1 = (*gen)();
+ float f[2];
+ BoxMullerFloat(x0, x1, &f[0], &f[1]);
+
+ for (int i = 0; i < 2; ++i) {
+ if (fabs(f[i]) < kTruncateValue) {
+ results[index++] = f[i];
+ if (index >= kResultElementCount) {
+ return results;
+ }
+ }
+ }
+ }
+ }
+};
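+
+// Illustrative usage sketch (added for exposition, not part of the original
+// change; the function name is hypothetical): because the truncated
+// distribution may consume a variable number of samples per output, it is
+// driven through a SingleSampleAdapter rather than through PhiloxRandom
+// directly. "out" must hold at least kResultElementCount floats.
+inline void ExampleFillTruncatedNormals(PhiloxRandom gen, float* out) {
+  typedef TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, float>
+      Dist;
+  SingleSampleAdapter<PhiloxRandom> single(&gen);
+  Dist dist;
+  Dist::ResultType batch = dist(&single);  // values all lie in (-2, 2)
+  for (int i = 0; i < Dist::kResultElementCount; ++i) {
+    out[i] = batch[i];
+  }
+}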
+
+// Partial specialization for double.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, double> {
+ public:
+ // The number of elements that will be returned.
+ static const int kResultElementCount =
+ (SingleSampleGenerator::kNativeElementCount > 1)
+ ? SingleSampleGenerator::kNativeElementCount / 2
+ : 1;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static const bool kVariableSamplesPerOutput = true;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+ const double kTruncateValue = 2.0;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator* gen) {
+ ResultType results;
+ int index = 0;
+ while (1) {
+ const uint32 x0 = (*gen)();
+ const uint32 x1 = (*gen)();
+ const uint32 x2 = (*gen)();
+ const uint32 x3 = (*gen)();
+ double d[2];
+ BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]);
+
+ for (int i = 0; i < 2; ++i) {
+ if (fabs(d[i]) < kTruncateValue) {
+ results[index++] = d[i];
+ if (index >= kResultElementCount) {
+ return results;
+ }
+ }
+ }
+ }
+ }
+};
+
+// Helper function to convert two 32-bit uniform integers to two floats
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) {
+ // This function implements the Box-Muller transform:
+ // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+ // Do not send a really small number to log().
+ // We cannot mark "epsilon" as "static const" because NVCC would complain
+ const float epsilon = 1.0e-7f;
+ float u1 = Uint32ToFloat(x0);
+ if (u1 < epsilon) {
+ u1 = epsilon;
+ }
+ const float v1 = 2.0f * M_PI * Uint32ToFloat(x1);
+ const float u2 = sqrt(-2.0f * log(u1));
+#if defined(__linux)
+ sincosf(v1, f0, f1);
+#else
+ *f0 = sinf(v1);
+ *f1 = cosf(v1);
+#endif
+ *f0 *= u2;
+ *f1 *= u2;
+}
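+
+// Note (added for exposition, not part of the original change): for a given
+// input pair the two outputs satisfy f0*f0 + f1*f1 == -2 * log(u1) up to
+// rounding, since (f0, f1) = u2 * (sin(v1), cos(v1)) with u2 = sqrt(-2 log u1).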
+
+// Helper function to convert four 32-bit uniform integers to two doubles
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
+ double* d1) {
+ // This function implements the Box-Muller transform:
+ // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+ // Do not send a really small number to log().
+ // We cannot mark "epsilon" as "static const" because NVCC would complain
+ const double epsilon = 1.0e-7;
+ double u1 = Uint64ToDouble(x0, x1);
+ if (u1 < epsilon) {
+ u1 = epsilon;
+ }
+ const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3);
+ const double u2 = sqrt(-2.0 * log(u1));
+#if defined(__linux)
+ sincos(v1, d0, d1);
+#else
+ *d0 = sin(v1);
+ *d1 = cos(v1);
+#endif
+ *d0 *= u2;
+ *d1 *= u2;
+}
+
+// Helper function to convert a 32-bit integer to a float in [0, 1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32 x) {
+ // IEEE754 floats are formatted as follows (MSB first):
+ // sign(1) exponent(8) mantissa(23)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 127 -- an excess 127 representation of a zero exponent
+ // mantissa == 23 random bits
+ const uint32 man = x & 0x7fffffu; // 23 bit mantissa
+ const uint32 exp = static_cast<uint32>(127);
+ const uint32 val = (exp << 23) | man;
+
+ // Assumes that endian-ness is same for float and uint32.
+ float result;
+ memcpy(&result, &val, sizeof(val));
+ return result - 1.0f;
+}
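+
+// Illustrative sanity check (added for exposition, not part of the original
+// change; the function name is hypothetical).
+inline bool ExampleUint32ToFloatRangeHolds() {
+  // x == 0 sets a zero mantissa with exponent 127, i.e. the bits of 1.0f,
+  // so the result is exactly 0.0f.
+  const bool low_ok = (Uint32ToFloat(0u) == 0.0f);
+  // An all-ones input fills the 23 mantissa bits, giving a value just below
+  // 2.0f, so the result stays strictly below 1.0f.
+  const bool high_ok = (Uint32ToFloat(0xffffffffu) < 1.0f);
+  return low_ok && high_ok;
+}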
+
+// Helper function to convert two 32-bit integers to a double in [0, 1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) {
+ // IEEE754 doubles are formatted as follows (MSB first):
+ // sign(1) exponent(11) mantissa(52)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 1023 -- an excess 1023 representation of a zero exponent
+ // mantissa == 52 random bits
+ const uint32 mhi = x0 & 0xfffffu; // upper 20 bits of mantissa
+ const uint32 mlo = x1; // lower 32 bits of mantissa
+ const uint64 man = (static_cast<uint64>(mhi) << 32) | mlo; // mantissa
+ const uint64 exp = static_cast<uint64>(1023);
+ const uint64 val = (exp << 52) | man;
+ // Assumes that endian-ness is same for double and uint64.
+ double result;
+ memcpy(&result, &val, sizeof(val));
+ return result - 1.0;
+}
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc
new file mode 100644
index 0000000000..3ce86a907a
--- /dev/null
+++ b/tensorflow/core/lib/random/random_distributions_test.cc
@@ -0,0 +1,270 @@
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+#include <math.h>
+#include <algorithm>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/philox_random_test_utils.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+// The largest z-value we want to tolerate. Since the z-statistic is
+// approximately unit normal under the null hypothesis, it should almost
+// never exceed 6.
+static constexpr float kZLimit = 6.0;
+
+// A utility function to fill the given array with samples from the given
+// distribution, using the single-sample adapter of the underlying generator.
+template <class Distribution>
+void FillRandomsWithSingles(PhiloxRandom gen,
+ typename Distribution::ResultElementType* p,
+ int64 size) {
+ int granularity = Distribution::kResultElementCount;
+
+ CHECK(size % granularity == 0) << " size: " << size
+ << " granularity: " << granularity;
+
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ Distribution dist;
+ for (int i = 0; i < size; i += granularity) {
+ auto sample = dist(&single_samples);
+ std::copy(&sample[0], &sample[0] + granularity, &p[i]);
+ }
+}
+
+// Checks that the given array of samples matches the given theoretical moment
+// function at different orders. The test is considered passing if the z-tests
+// of all statistical moments are below z_limit.
+// typename T in the template argument could be either float or double.
+// Arguments:
+// samples: an array of samples to be tested for their statistical properties;
+// theoretical_moments: a functor that returns the theoretical moment of the
+// distribution for an arbitrary order;
+// max_moments: the largest moment order to be tested;
+// stride: the distance between samples used when forming products;
+// 0 tests the n-th moment of each individual sample, while
+// any other stride tests for spatial correlation between samples;
+// z_limit: the largest z-statistic for which the test is considered to pass.
+template <typename T>
+bool CheckSamplesMoments(const std::vector<T>& samples,
+ std::function<double(int)> theoretical_moments,
+ int max_moments, int stride, T z_limit) {
+ const T* const samples_data = &samples[0];
+ const int samples_size = samples.size();
+ std::vector<double> moments(max_moments + 1);
+ double* const moments_data = &moments[0];
+ std::vector<int> moments_sample_count(max_moments + 1);
+ int* const moments_sample_count_data = &moments_sample_count[0];
+
+ for (int k = 0; k < samples_size; ++k) {
+ double moment = 1.;
+ for (int i = 0; i <= max_moments; ++i) {
+ int index = k + i * stride;
+ if (index >= samples_size) {
+ break;
+ }
+ // moments[i] stores the measured i-th order moment.
+ // Bypass std::vector::operator[] because it is too slow in debug
+ // mode, given the large number of samples.
+ moments_data[i] += moment;
+ ++moments_sample_count_data[i];
+ moment *= samples_data[index];
+ }
+ }
+
+ // normalize the moments
+ for (int i = 0; i <= max_moments; ++i) {
+ moments[i] /= moments_sample_count[i];
+ }
+
+ bool status = true;
+
+ for (int i = 1; i <= max_moments; ++i) {
+ // Calculate the theoretical mean and variance
+ const double moments_i_mean = (stride == 0)
+ ? theoretical_moments(i)
+ : std::pow(theoretical_moments(1), i);
+ const double moments_i_squared = (stride == 0)
+ ? theoretical_moments(2 * i)
+ : std::pow(theoretical_moments(2), i);
+ const double moments_i_var =
+ moments_i_squared - moments_i_mean * moments_i_mean;
+
+ // assume every operation has a small numerical error.
+ static const double kNumericalError = 1e-6;
+ // it takes i multiplications to calculate one i-th moment.
+ const double error_per_moment = i * kNumericalError;
+ const double total_variance =
+ moments_i_var / moments_sample_count[i] + error_per_moment;
+ // z_test is approximately a unit normal distribution.
+ const double z_test =
+ fabs((moments[i] - moments_i_mean) / sqrt(total_variance));
+
+ if (z_test > z_limit) {
+ LOG(ERROR) << "failing z_test:"
+ << " moment: " << i << " stride: " << stride
+ << " z_test: " << z_test << " z_limit: " << z_limit
+ << " measured moments: " << moments[i]
+ << " theoretical mean of the moments: " << moments_i_mean
+ << " theoretical var of the moments: " << moments_i_var
+ << " sample count: " << moments_sample_count[i];
+ status = false;
+ }
+ }
+
+ return status;
+}
+
+// This test checks that the generated samples match the theoretical moments
+// of the uniform distribution.
+template <typename T>
+void UniformMomentsTest(int count, int max_moments,
+ const std::vector<int>& strides, T z_limit) {
+ auto uniform_moments = [](int n) -> double { return 1. / (n + 1); };
+
+ std::vector<T> v1(count);
+ uint64 seed = GetTestSeed();
+ PhiloxRandom gen(seed);
+ FillRandoms<UniformDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
+ for (int stride : strides) {
+ bool status = CheckSamplesMoments<T>(v1, uniform_moments, max_moments,
+ stride, z_limit);
+ ASSERT_TRUE(status) << " UniformMomentsTest failing. seed: " << seed;
+ }
+}
+
+// This test checks that the generated samples match the theoretical moments
+// of the unit normal distribution.
+template <typename T>
+void NormalMomentsTest(int count, int max_moments,
+ const std::vector<int>& strides, T z_limit) {
+ auto normal_moments = [](int n) -> double {
+ if (n % 2 == 1) {
+ // For an odd order, the moment of a unit normal distribution is zero.
+ return 0.;
+ } else {
+ // For an even order n, the moment of a unit normal distribution is
+ // (n - 1)!!, the double factorial of n - 1.
+ double v = 1.;
+ for (int i = n - 1; i >= 1; i -= 2) {
+ v *= i;
+ }
+ return v;
+ }
+ };
+
+ std::vector<T> v1(count);
+ uint64 seed = GetTestSeed();
+ PhiloxRandom gen(seed);
+ FillRandoms<NormalDistribution<PhiloxRandom, T> >(gen, &v1[0], v1.size());
+
+ for (int stride : strides) {
+ bool status = CheckSamplesMoments<T>(v1, normal_moments, max_moments,
+ stride, z_limit);
+ ASSERT_TRUE(status) << " NormalMomentsTest failing. seed: " << seed;
+ }
+}
+
+// A functor to calculate the moments for the truncated normal distribution.
+// For any odd order, the moment is zero. For any even order n, the moments
+// of the truncated standard normal satisfy the recursion
+// m(n) = (n - 1) * m(n - 2) - 2 * v ^ (n - 1) * f(v) / (2 * Phi(v) - 1)
+// where v is the cut-off value, f(v) is the p.d.f. of the standard
+// normal, and Phi(v) is the c.d.f. of the standard normal.
+class TruncatedNormalMoments {
+ public:
+ double operator()(int n) {
+ if (n == 0) {
+ return 1;
+ }
+ if (n % 2 == 1) {
+ // For an odd order, the moment is always zero
+ return 0.;
+ }
+
+ // Memoization: check the cached results first.
+ auto iter = cached_results_.find(n);
+ if (iter != cached_results_.end()) {
+ return iter->second;
+ }
+
+ // The real computation of the moment.
+ double bias = 2.0 * std::pow(kV, n - 1) * kFV / (2.0 * kPhiV - 1.0);
+ double moment_n_minus_2 = (*this)(n - 2);
+ double moment_n = (n - 1) * moment_n_minus_2 - bias;
+
+ cached_results_[n] = moment_n;
+ return moment_n;
+ }
+
+ private:
+ const double kV = 2.0;
+ // f(v), where f is the p.d.f of the normal distribution and v=2.
+ const double kFV = 1.0 / sqrt(2.0 * M_PI) * exp(-kV * kV / 2.0);
+ // The numerical evaluation of Phi(v), where v is the truncate value.
+ // v = 2 in the current implementation.
+ const double kPhiV = 0.977249868051821;
+ std::unordered_map<int, double> cached_results_;
+};
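+
+// Worked example (added for exposition, not part of the original change; the
+// test name is hypothetical): with the cutoff v = 2 and m(0) = 1, the
+// recursion gives m(2) = m(0) - 2 * v * f(v) / (2 * Phi(v) - 1), roughly
+// 0.774, the variance of a standard normal truncated to [-2, 2].
+TEST(TruncatedNormalMomentsTest, SecondMomentMatchesClosedForm) {
+  TruncatedNormalMoments moments;
+  const double kV = 2.0;
+  const double kFV = 1.0 / sqrt(2.0 * M_PI) * exp(-kV * kV / 2.0);
+  const double kPhiV = 0.977249868051821;
+  const double expected = 1.0 - 2.0 * kV * kFV / (2.0 * kPhiV - 1.0);
+  EXPECT_NEAR(expected, moments(2), 1e-12);
+}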
+
+// This test checks that the generated samples match the theoretical moments
+// of the truncated normal distribution.
+template <typename T>
+void RandomParametersMomentsTest(int count, int max_moments,
+ const std::vector<int>& strides, T z_limit) {
+ std::vector<T> v1(count);
+ uint64 seed = GetTestSeed();
+ PhiloxRandom gen(seed);
+ FillRandomsWithSingles<
+ TruncatedNormalDistribution<SingleSampleAdapter<PhiloxRandom>, T> >(
+ gen, &v1[0], v1.size());
+
+ for (int stride : strides) {
+ bool status = CheckSamplesMoments<T>(v1, TruncatedNormalMoments(),
+ max_moments, stride, z_limit);
+ ASSERT_TRUE(status) << " RandomParametersMomentsTest failing. seed: " << seed;
+ }
+}
+
+TEST(PhiloxRandomTest, UniformFloatMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ UniformMomentsTest<float>(1 << 20, 40, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, NormalFloatMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ NormalMomentsTest<float>(8 << 20, 25, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, RandomParametersFloatMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ RandomParametersMomentsTest<float>(1 << 20, 40, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, UniformDoubleMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ UniformMomentsTest<double>(1 << 20, 40, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, NormalDoubleMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ NormalMomentsTest<double>(8 << 20, 25, strides, kZLimit);
+}
+
+TEST(PhiloxRandomTest, RandomParametersDoubleMomentsTest) {
+ const std::vector<int> strides = {0, 1, 4, 17};
+ RandomParametersMomentsTest<double>(1 << 20, 40, strides, kZLimit);
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/random_test.cc b/tensorflow/core/lib/random/random_test.cc
new file mode 100644
index 0000000000..7ed37c8b5e
--- /dev/null
+++ b/tensorflow/core/lib/random/random_test.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/lib/random/random.h"
+
+#include <set>
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+TEST(New64Test, SanityCheck) {
+ std::set<uint64> values;
+ for (int i = 0; i < 1000000; i++) {
+ uint64 x = New64();
+ EXPECT_TRUE(values.insert(x).second) << "duplicate " << x;
+ }
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/simple_philox.cc b/tensorflow/core/lib/random/simple_philox.cc
new file mode 100644
index 0000000000..1035e1f017
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/random/exact_uniform_int.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace random {
+
+uint32 SimplePhilox::Uniform(uint32 n) {
+ return ExactUniformInt<uint32>(n, [this]() { return Rand32(); });
+}
+
+uint64 SimplePhilox::Uniform64(uint64 n) {
+ return ExactUniformInt<uint64>(n, [this]() { return Rand64(); });
+}
+
+uint32 SimplePhilox::Skewed(int max_log) {
+ CHECK(0 <= max_log && max_log <= 32);
+ const int shift = Rand32() % (max_log + 1);
+ const uint32 mask = shift == 32 ? ~static_cast<uint32>(0) : (1 << shift) - 1;
+ return Rand32() & mask;
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h
new file mode 100644
index 0000000000..12b15d7616
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox.h
@@ -0,0 +1,61 @@
+#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+
+#include <math.h>
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+namespace random {
+
+// A simple imperative interface to Philox
+class SimplePhilox {
+ public:
+ PHILOX_DEVICE_INLINE
+ explicit SimplePhilox(PhiloxRandom* gen) : single_(gen) {}
+
+ // 32 random bits
+ PHILOX_DEVICE_INLINE uint32 Rand32() { return single_(); }
+
+ // 64 random bits
+ PHILOX_DEVICE_INLINE uint64 Rand64() {
+ const uint32 lo = single_(), hi = single_();
+ return lo | static_cast<uint64>(hi) << 32;
+ }
+
+ // Uniform float in [0, 1)
+ PHILOX_DEVICE_INLINE float RandFloat() { return Uint32ToFloat(single_()); }
+
+ // Uniform double in [0, 1)
+ PHILOX_DEVICE_INLINE double RandDouble() {
+ const uint32 x0 = single_(), x1 = single_();
+ return Uint64ToDouble(x0, x1);
+ }
+
+ // Uniform integer in [0, n).
+ // Uses rejection sampling, so may need more than one 32-bit sample.
+ uint32 Uniform(uint32 n);
+
+ // Uniform integer in [0, n).
+ // Uses rejection sampling, so may need more than one 64-bit sample.
+ uint64 Uniform64(uint64 n);
+
+ // True with probability 1/n.
+ bool OneIn(uint32 n) { return Uniform(n) == 0; }
+
+ // Skewed: pick "base" uniformly from range [0,max_log] and then
+ // return "base" random bits. The effect is to pick a number in the
+ // range [0,2^max_log-1] with bias towards smaller numbers.
+ uint32 Skewed(int max_log);
+
+ private:
+ SingleSampleAdapter<PhiloxRandom> single_;
+};
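+
+// Illustrative usage sketch (added for exposition, not part of the original
+// change; the function name is hypothetical).
+inline uint32 ExampleUniformDie(PhiloxRandom* philox) {
+  SimplePhilox gen(philox);
+  // Uniform() rejects a small slice of the 32-bit range, so the result is
+  // exactly uniform over {0, ..., 5}; add 1 to get a die face.
+  return gen.Uniform(6) + 1;
+}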
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
diff --git a/tensorflow/core/lib/random/simple_philox_test.cc b/tensorflow/core/lib/random/simple_philox_test.cc
new file mode 100644
index 0000000000..4246b8b4dd
--- /dev/null
+++ b/tensorflow/core/lib/random/simple_philox_test.cc
@@ -0,0 +1,120 @@
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+#include <set>
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+namespace {
+
+TEST(SimplePhiloxTest, FloatTest) {
+ PhiloxRandom philox(7, 7);
+ SimplePhilox gen(&philox);
+ static const int kIters = 1000000;
+ for (int i = 0; i < kIters; ++i) {
+ float f = gen.RandFloat();
+ EXPECT_LE(0.0f, f);
+ EXPECT_GT(1.0f, f);
+ }
+ for (int i = 0; i < kIters; ++i) {
+ double d = gen.RandDouble();
+ EXPECT_LE(0.0, d);
+ EXPECT_GT(1.0, d);
+ }
+}
+
+static void DifferenceTest(const char *names, SimplePhilox *gen1,
+ SimplePhilox *gen2) {
+ static const int kIters = 100;
+ bool different = false;
+ for (int i = 0; i < kIters; ++i) {
+ if (gen1->Rand32() != gen2->Rand32()) {
+ different = true;
+ break;
+ }
+ }
+ CHECK(different) << "different seeds but same output!";
+}
+
+TEST(SimplePhiloxTest, DifferenceTest) {
+ PhiloxRandom philox1(1, 1), philox2(17, 17);
+ SimplePhilox gen1(&philox1), gen2(&philox2);
+
+ DifferenceTest("SimplePhilox: different seeds", &gen1, &gen2);
+}
+
+TEST(SimplePhiloxTest, DifferenceTestCloseSeeds) {
+ PhiloxRandom philox1(1, 1), philox2(2, 1);
+ SimplePhilox gen1(&philox1), gen2(&philox2);
+
+ DifferenceTest("SimplePhilox: close seeds", &gen1, &gen2);
+}
+
+TEST(SimplePhiloxTest, Regression_CloseSeedsAreDifferent) {
+ const int kCount = 1000;
+
+ // Two seeds differ only by the last bit.
+ PhiloxRandom philox1(0, 1), philox2(1, 1);
+ SimplePhilox gen1(&philox1), gen2(&philox2);
+
+ std::set<uint32> first;
+ std::set<uint32> all;
+ for (int i = 0; i < kCount; ++i) {
+ uint32 v = gen1.Rand32();
+ first.insert(v);
+ all.insert(v);
+ all.insert(gen2.Rand32());
+ }
+
+ // A broken array-initialization implementation (before 2009-08-18) using
+ // the above seeds returned <1000, 1007>, i.e. output that is >99% similar.
+ // The fixed implementation returns <1000, 2000>: completely disjoint sets.
+ EXPECT_EQ(kCount, first.size());
+ EXPECT_EQ(2 * kCount, all.size());
+}
+
+TEST(SimplePhiloxTest, TestUniform) {
+ PhiloxRandom philox(17, 17);
+ SimplePhilox gen(&philox);
+
+ uint32 range = 3 * (1L << 29);
+ uint32 threshold = 1L << 30;
+
+ size_t count = 0;
+ static const int kTrials = 100000;
+ for (int i = 0; i < kTrials; ++i) {
+ uint32 rnd = gen.Uniform(range);
+ if (rnd < threshold) {
+ ++count;
+ }
+ }
+
+ EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005);
+}
+
+TEST(SimplePhiloxTest, TestUniform64) {
+ PhiloxRandom philox(17, 17);
+ SimplePhilox gen(&philox);
+
+ uint64 range = 3 * (1LL << 59);
+ uint64 threshold = 1LL << 60;
+
+ size_t count = 0;
+ static const int kTrials = 100000;
+ for (int i = 0; i < kTrials; ++i) {
+ uint64 rnd = gen.Uniform64(range);
+ if (rnd < threshold) {
+ ++count;
+ }
+ }
+
+ EXPECT_LT(fabs((threshold + 0.0) / range - (count + 0.0) / kTrials), 0.005);
+}
+
+} // namespace
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/weighted_picker.cc b/tensorflow/core/lib/random/weighted_picker.cc
new file mode 100644
index 0000000000..f96da578ec
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.cc
@@ -0,0 +1,203 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <algorithm>
+
+#include "tensorflow/core/lib/random/simple_philox.h"
+
+namespace tensorflow {
+namespace random {
+
+WeightedPicker::WeightedPicker(int N) {
+ CHECK_GE(N, 0);
+ N_ = N;
+
+ // Find the number of levels
+ num_levels_ = 1;
+ while (LevelSize(num_levels_ - 1) < N) {
+ num_levels_++;
+ }
+
+ // Initialize the levels
+ level_ = new int32*[num_levels_];
+ for (int l = 0; l < num_levels_; l++) {
+ level_[l] = new int32[LevelSize(l)];
+ }
+
+ SetAllWeights(1);
+}
+
+WeightedPicker::~WeightedPicker() {
+ for (int l = 0; l < num_levels_; l++) {
+ delete[] level_[l];
+ }
+ delete[] level_;
+}
+
+static int32 UnbiasedUniform(SimplePhilox* r, int32 n) {
+ CHECK_LE(0, n);
+ const uint32 range = ~static_cast<uint32>(0);
+ if (n == 0) {
+ return r->Rand32() * n;
+ } else if (0 == (n & (n - 1))) {
+ // N is a power of two, so just mask off the lower bits.
+ return r->Rand32() & (n - 1);
+ } else {
+ // Reject all numbers that skew the distribution towards 0.
+
+ // Rand32's output is uniform in the half-open interval [0, 2^{32}).
+ // For any interval [m,n), the number of elements in it is n-m.
+
+ uint32 rem = (range % n) + 1;
+ uint32 rnd;
+
+ // rem = ((2^{32}-1) \bmod n) + 1
+ // 1 <= rem <= n
+
+ // NB: rem == n is impossible, since n is not a power of 2 (from
+ // earlier check).
+
+ do {
+ rnd = r->Rand32(); // rnd uniform over [0, 2^{32})
+ } while (rnd < rem); // reject [0, rem)
+ // rnd is uniform over [rem, 2^{32})
+ //
+ // The number of elements in the half-open interval is
+ //
+ // 2^{32} - rem = 2^{32} - ((2^{32}-1) \bmod n) - 1
+ // = 2^{32}-1 - ((2^{32}-1) \bmod n)
+ // = n \cdot \lfloor (2^{32}-1)/n \rfloor
+ //
+ // therefore n evenly divides the number of integers in the
+ // interval.
+ //
+ // The function v \rightarrow v % n maps the interval [rem,
+ // 2^{32}) onto [0, n). Each integer in the range [0, n)
+ // has exactly \lfloor (2^{32}-1)/n \rfloor preimages in
+ // the domain interval.
+ //
+ // Therefore, v % n is uniform over [0, n). QED.
+
+ return rnd % n;
+ }
+}
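+
+// Worked example (added for exposition, not part of the original change):
+// for n == 3, (2^{32} - 1) is divisible by 3, so rem == 1 and only rnd == 0
+// is rejected; the remaining 2^{32} - 1 values split into exactly
+// (2^{32} - 1) / 3 == 1431655765 preimages per residue class, making
+// rnd % 3 exactly uniform over {0, 1, 2}.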
+
+int WeightedPicker::Pick(SimplePhilox* rnd) const {
+ if (total_weight() == 0) return -1;
+
+ // using unbiased uniform distribution to avoid bias
+ // toward low elements resulting from a possible use
+ // of big weights.
+ return PickAt(UnbiasedUniform(rnd, total_weight()));
+}
+
+int WeightedPicker::PickAt(int32 weight_index) const {
+ if (weight_index < 0 || weight_index >= total_weight()) return -1;
+
+ int32 position = weight_index;
+ int index = 0;
+
+ for (int l = 1; l < num_levels_; l++) {
+ // Pick left or right child of "level_[l-1][index]"
+ const int32 left_weight = level_[l][2 * index];
+ if (position < left_weight) {
+ // Descend to left child
+ index = 2 * index;
+ } else {
+ // Descend to right child
+ index = 2 * index + 1;
+ position -= left_weight;
+ }
+ }
+ CHECK_GE(index, 0);
+ CHECK_LT(index, N_);
+ CHECK_LE(position, level_[num_levels_ - 1][index]);
+ return index;
+}
+
+void WeightedPicker::set_weight(int index, int32 weight) {
+ assert(index >= 0);
+ assert(index < N_);
+
+ // Adjust the sums all the way up to the root
+ const int32 delta = weight - get_weight(index);
+ for (int l = num_levels_ - 1; l >= 0; l--) {
+ level_[l][index] += delta;
+ index >>= 1;
+ }
+}
+
+void WeightedPicker::SetAllWeights(int32 weight) {
+ // Initialize leaves
+ int32* leaves = level_[num_levels_ - 1];
+ for (int i = 0; i < N_; i++) leaves[i] = weight;
+ for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
+
+ // Now sum up towards the root
+ RebuildTreeWeights();
+}
+
+void WeightedPicker::SetWeightsFromArray(int N, const int32* weights) {
+ Resize(N);
+
+ // Initialize leaves
+ int32* leaves = level_[num_levels_ - 1];
+ for (int i = 0; i < N_; i++) leaves[i] = weights[i];
+ for (int i = N_; i < LevelSize(num_levels_ - 1); i++) leaves[i] = 0;
+
+ // Now sum up towards the root
+ RebuildTreeWeights();
+}
+
+void WeightedPicker::RebuildTreeWeights() {
+ for (int l = num_levels_ - 2; l >= 0; l--) {
+ int32* level = level_[l];
+ int32* children = level_[l + 1];
+ for (int i = 0; i < LevelSize(l); i++) {
+ level[i] = children[2 * i] + children[2 * i + 1];
+ }
+ }
+}
+
+void WeightedPicker::Append(int32 weight) {
+ Resize(num_elements() + 1);
+ set_weight(num_elements() - 1, weight);
+}
+
+void WeightedPicker::Resize(int new_size) {
+ CHECK_GE(new_size, 0);
+ if (new_size <= LevelSize(num_levels_ - 1)) {
+ // The new picker fits in the existing levels.
+
+ // First zero out any of the weights that are being dropped so
+ // that the levels are correct (only needed when shrinking)
+ for (int i = new_size; i < N_; i++) {
+ set_weight(i, 0);
+ }
+
+ // We do not need to set any new weights when enlarging because
+ // the unneeded entries always have weight zero.
+ N_ = new_size;
+ return;
+ }
+
+ // We follow the simple strategy of just copying the old
+ // WeightedPicker into a new WeightedPicker. The cost is
+ // O(N) regardless.
+ assert(new_size > N_);
+ WeightedPicker new_picker(new_size);
+ int32* dst = new_picker.level_[new_picker.num_levels_ - 1];
+ int32* src = this->level_[this->num_levels_ - 1];
+ memcpy(dst, src, sizeof(dst[0]) * N_);
+ memset(dst + N_, 0, sizeof(dst[0]) * (new_size - N_));
+ new_picker.RebuildTreeWeights();
+
+ // Now swap the two pickers
+ std::swap(new_picker.N_, this->N_);
+ std::swap(new_picker.num_levels_, this->num_levels_);
+ std::swap(new_picker.level_, this->level_);
+ assert(this->N_ == new_size);
+}
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h
new file mode 100644
index 0000000000..3d2c2dbb39
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker.h
@@ -0,0 +1,118 @@
+
+// An abstraction to pick from one of N elements with a specified
+// weight per element.
+//
+// The weight for a given element can be changed in O(lg N) time.
+// An element can be picked in O(lg N) time.
+//
+// Uses O(N) bytes of memory.
+//
+// Alternative: distribution-sampler.h allows O(1) time picking, but no weight
+// adjustment after construction.
+
+#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
+
+#include <assert.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace random {
+
+class SimplePhilox;
+
+class WeightedPicker {
+ public:
+ // REQUIRES N >= 0
+ // Initializes the elements with a weight of one per element
+ explicit WeightedPicker(int N);
+
+ // Releases all resources
+ ~WeightedPicker();
+
+ // Pick a random element with probability proportional to its weight.
+ // If total weight is zero, returns -1.
+ int Pick(SimplePhilox* rnd) const;
+
+ // Deterministically pick element x whose weight covers the
+ // specified weight_index.
+ // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
+ int PickAt(int32 weight_index) const;
+
+ // Get the weight associated with an element
+ // REQUIRES 0 <= index < N
+ int32 get_weight(int index) const;
+
+ // Set the weight associated with an element
+ // REQUIRES weight >= 0
+ // REQUIRES 0 <= index < N
+ void set_weight(int index, int32 weight);
+
+ // Get the total combined weight of all elements
+ int32 total_weight() const;
+
+ // Get the number of elements in the picker
+ int num_elements() const;
+
+ // Set weight of each element to "weight"
+ void SetAllWeights(int32 weight);
+
+ // Resizes the picker to N and
+ // sets the weight of each element i to weight[i].
+ // The sum of the weights should not exceed 2^31 - 2
+ // Complexity O(N).
+ void SetWeightsFromArray(int N, const int32* weights);
+
+ // REQUIRES N >= 0
+ //
+ // Resize the weighted picker so that it has "N" elements.
+ // Any newly added entries have zero weight.
+ //
+ // Note: Resizing to a smaller size than num_elements() will
+ // not reclaim any memory. If you wish to reduce memory usage,
+ // allocate a new WeightedPicker of the appropriate size.
+ //
+ // It is efficient to use repeated calls to Resize(num_elements() + 1)
+ // to grow the picker to size X (takes total time O(X)).
+ void Resize(int N);
+
+ // Grow the picker by one and set the weight of the new entry to "weight".
+ //
+ // Repeated calls to Append() in order to grow the
+ // picker to size X takes a total time of O(X lg(X)).
+ // Consider using SetWeightsFromArray instead.
+ void Append(int32 weight);
+
+ private:
+ // We keep a binary tree with N leaves. The "i"th leaf contains
+ // the weight of the "i"th element. An internal node contains
+ // the sum of the weights of its children.
+ int N_; // Number of elements
+ int num_levels_; // Number of levels in tree (level-0 is root)
+ int32** level_; // Array that holds nodes per level
+
+ // Size of each level
+ static int LevelSize(int level) { return 1 << level; }
+
+ // Rebuild the tree weights using the leaf weights
+ void RebuildTreeWeights();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
+};
+
+inline int32 WeightedPicker::get_weight(int index) const {
+ DCHECK_GE(index, 0);
+ DCHECK_LT(index, N_);
+ return level_[num_levels_ - 1][index];
+}
+
+inline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
+
+inline int WeightedPicker::num_elements() const { return N_; }
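+
+// Illustrative usage sketch (added for exposition, not part of the original
+// change; the function name and weights are hypothetical): element 2 carries
+// half of the total weight, so Pick() returns it roughly half of the time.
+inline int ExampleWeightedPick(SimplePhilox* rnd) {
+  WeightedPicker picker(3);
+  picker.set_weight(0, 1);  // each set_weight() call is O(lg N)
+  picker.set_weight(1, 2);
+  picker.set_weight(2, 3);
+  return picker.Pick(rnd);  // O(lg N); returns -1 only if total weight is 0
+}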
+
+} // namespace random
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
diff --git a/tensorflow/core/lib/random/weighted_picker_test.cc b/tensorflow/core/lib/random/weighted_picker_test.cc
new file mode 100644
index 0000000000..0b27d437d5
--- /dev/null
+++ b/tensorflow/core/lib/random/weighted_picker_test.cc
@@ -0,0 +1,254 @@
+#include "tensorflow/core/lib/random/weighted_picker.h"
+
+#include <string.h>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace random {
+
+static void TestPicker(SimplePhilox* rnd, int size);
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials);
+static void TestPickAt(int items, const int32* weights);
+
+TEST(WeightedPicker, Simple) {
+ PhiloxRandom philox(testing::RandomSeed(), 17);
+ SimplePhilox rnd(&philox);
+
+ {
+ VLOG(0) << "======= Zero-length picker";
+ WeightedPicker picker(0);
+ EXPECT_EQ(picker.Pick(&rnd), -1);
+ }
+
+ {
+ VLOG(0) << "======= Singleton picker";
+ WeightedPicker picker(1);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ }
+
+ {
+ VLOG(0) << "======= Grown picker";
+ WeightedPicker picker(0);
+ for (int i = 0; i < 10; i++) {
+ picker.Append(1);
+ }
+ CheckUniform(&rnd, &picker, 100000);
+ }
+
+ {
+ VLOG(0) << "======= Grown picker with zero weights";
+ WeightedPicker picker(1);
+ picker.Resize(10);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ }
+
+ {
+ VLOG(0) << "======= Shrink picker and check weights";
+ WeightedPicker picker(1);
+ picker.Resize(10);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ EXPECT_EQ(picker.Pick(&rnd), 0);
+ for (int i = 0; i < 10; i++) {
+ picker.set_weight(i, i);
+ }
+ EXPECT_EQ(picker.total_weight(), 45);
+ picker.Resize(5);
+ EXPECT_EQ(picker.total_weight(), 10);
+ picker.Resize(2);
+ EXPECT_EQ(picker.total_weight(), 1);
+ picker.Resize(1);
+ EXPECT_EQ(picker.total_weight(), 0);
+ }
+}
+
+TEST(WeightedPicker, BigWeights) {
+ PhiloxRandom philox(testing::RandomSeed() + 1, 17);
+ SimplePhilox rnd(&philox);
+ VLOG(0) << "======= Check uniform with big weights";
+ WeightedPicker picker(2);
+ picker.SetAllWeights(2147483646L / 3); // (2^31 - 2) / 3
+ CheckUniform(&rnd, &picker, 100000);
+}
+
+TEST(WeightedPicker, Deterministic) {
+ VLOG(0) << "======= Testing deterministic pick";
+ static const int32 weights[] = {1, 0, 200, 5, 42};
+ TestPickAt(TF_ARRAYSIZE(weights), weights);
+}
+
+TEST(WeightedPicker, Randomized) {
+ PhiloxRandom philox(testing::RandomSeed() + 10, 17);
+ SimplePhilox rnd(&philox);
+ TestPicker(&rnd, 1);
+ TestPicker(&rnd, 2);
+ TestPicker(&rnd, 3);
+ TestPicker(&rnd, 4);
+ TestPicker(&rnd, 7);
+ TestPicker(&rnd, 8);
+ TestPicker(&rnd, 9);
+ TestPicker(&rnd, 10);
+ TestPicker(&rnd, 100);
+}
+
+static void TestPicker(SimplePhilox* rnd, int size) {
+ VLOG(0) << "======= Testing size " << size;
+
+ // Check that empty picker returns -1
+ {
+ WeightedPicker picker(size);
+ picker.SetAllWeights(0);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), -1);
+ }
+
+ // Create zero weights array
+ std::vector<int32> weights(size);
+ for (int elem = 0; elem < size; elem++) {
+ weights[elem] = 0;
+ }
+
+ // Check that singleton picker always returns the same element
+ for (int elem = 0; elem < size; elem++) {
+ WeightedPicker picker(size);
+ picker.SetAllWeights(0);
+ picker.set_weight(elem, elem + 1);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
+ weights[elem] = 10;
+ picker.SetWeightsFromArray(size, &weights[0]);
+ for (int i = 0; i < 100; i++) EXPECT_EQ(picker.Pick(rnd), elem);
+ weights[elem] = 0;
+ }
+
+ // Check that uniform picker generates elements roughly uniformly
+ {
+ WeightedPicker picker(size);
+ CheckUniform(rnd, &picker, 100000);
+ }
+
+ // Check uniform picker that was grown piecemeal
+ if (size / 3 > 0) {
+ WeightedPicker picker(size / 3);
+ while (picker.num_elements() != size) {
+ picker.Append(1);
+ }
+ CheckUniform(rnd, &picker, 100000);
+ }
+
+ // Check that skewed distribution works
+ if (size <= 10) {
+ // When picker grows one element at a time
+ WeightedPicker picker(size);
+ int32 weight = 1;
+ for (int elem = 0; elem < size; elem++) {
+ picker.set_weight(elem, weight);
+ weights[elem] = weight;
+ weight *= 2;
+ }
+ CheckSkewed(rnd, &picker, 1000000);
+
+ // When picker is created from an array
+ WeightedPicker array_picker(0);
+ array_picker.SetWeightsFromArray(size, &weights[0]);
+ CheckSkewed(rnd, &array_picker, 1000000);
+ }
+}
+
+static void CheckUniform(SimplePhilox* rnd, WeightedPicker* picker,
+ int trials) {
+ const int size = picker->num_elements();
+ int* count = new int[size];
+ memset(count, 0, sizeof(count[0]) * size);
+ for (int i = 0; i < size * trials; i++) {
+ const int elem = picker->Pick(rnd);
+ EXPECT_GE(elem, 0);
+ EXPECT_LT(elem, size);
+ count[elem]++;
+ }
+ const int expected_min = int(0.9 * trials);
+ const int expected_max = int(1.1 * trials);
+ for (int i = 0; i < size; i++) {
+ EXPECT_GE(count[i], expected_min);
+ EXPECT_LE(count[i], expected_max);
+ }
+ delete[] count;
+}
+
+static void CheckSkewed(SimplePhilox* rnd, WeightedPicker* picker, int trials) {
+ const int size = picker->num_elements();
+ int* count = new int[size];
+ memset(count, 0, sizeof(count[0]) * size);
+ for (int i = 0; i < size * trials; i++) {
+ const int elem = picker->Pick(rnd);
+ EXPECT_GE(elem, 0);
+ EXPECT_LT(elem, size);
+ count[elem]++;
+ }
+
+ for (int i = 0; i < size - 1; i++) {
+ LOG(INFO) << i << ": " << count[i];
+ const float ratio = float(count[i + 1]) / float(count[i]);
+ EXPECT_GE(ratio, 1.6f);
+ EXPECT_LE(ratio, 2.4f);
+ }
+ delete[] count;
+}
+
+static void TestPickAt(int items, const int32* weights) {
+ WeightedPicker picker(items);
+ picker.SetWeightsFromArray(items, weights);
+ int weight_index = 0;
+ for (int i = 0; i < items; ++i) {
+ for (int j = 0; j < weights[i]; ++j) {
+ int pick = picker.PickAt(weight_index);
+ EXPECT_EQ(pick, i);
+ ++weight_index;
+ }
+ }
+ EXPECT_EQ(weight_index, picker.total_weight());
+}
+
+static void BM_Create(int iters, int arg) {
+ while (--iters > 0) {
+ WeightedPicker p(arg);
+ }
+}
+BENCHMARK(BM_Create)->Range(1, 1024);
+
+static void BM_CreateAndSetWeights(int iters, int arg) {
+ std::vector<int32> weights(arg);
+ for (int i = 0; i < arg; i++) {
+ weights[i] = i * 10;
+ }
+ while (--iters > 0) {
+ WeightedPicker p(arg);
+ p.SetWeightsFromArray(arg, &weights[0]);
+ }
+}
+BENCHMARK(BM_CreateAndSetWeights)->Range(1, 1024);
+
+static void BM_Pick(int iters, int arg) {
+ PhiloxRandom philox(301, 17);
+ SimplePhilox rnd(&philox);
+ WeightedPicker p(arg);
+ int result = 0;
+ while (--iters > 0) {
+ result += p.Pick(&rnd);
+ }
+ VLOG(4) << result; // Dummy use
+}
+BENCHMARK(BM_Pick)->Range(1, 1024);
+
+} // namespace random
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
new file mode 100644
index 0000000000..d61129fb3f
--- /dev/null
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -0,0 +1,260 @@
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include <float.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <algorithm>
+#include <cmath>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace strings {
+
+char* FastInt32ToBufferLeft(int32 i, char* buffer) {
+ uint32 u = i;
+ if (i < 0) {
+ *buffer++ = '-';
+ // We need to do the negation in modular (i.e., "unsigned")
+ // arithmetic; MSVC++ apparently warns for plain "-u", so
+ // we write the equivalent expression "0 - u" instead.
+ u = 0 - u;
+ }
+ return FastUInt32ToBufferLeft(u, buffer);
+}
+
+char* FastUInt32ToBufferLeft(uint32 i, char* buffer) {
+ char* start = buffer;
+ do {
+ *buffer++ = ((i % 10) + '0');
+ i /= 10;
+ } while (i > 0);
+ *buffer = 0;
+ std::reverse(start, buffer);
+ return buffer;
+}
+
+char* FastInt64ToBufferLeft(int64 i, char* buffer) {
+ uint64 u = i;
+ if (i < 0) {
+ *buffer++ = '-';
+ u = 0 - u;
+ }
+ return FastUInt64ToBufferLeft(u, buffer);
+}
+
+char* FastUInt64ToBufferLeft(uint64 i, char* buffer) {
+ char* start = buffer;
+ do {
+ *buffer++ = ((i % 10) + '0');
+ i /= 10;
+ } while (i > 0);
+ *buffer = 0;
+ std::reverse(start, buffer);
+ return buffer;
+}
+
+static const double kDoublePrecisionCheckMax = DBL_MAX / 1.000000000000001;
+
+char* DoubleToBuffer(double value, char* buffer) {
+ // DBL_DIG is 15 for IEEE-754 doubles, which are used on almost all
+ // platforms these days. Just in case some system exists where DBL_DIG
+ // is significantly larger -- and risks overflowing our buffer -- we have
+ // this assert.
+ static_assert(DBL_DIG < 20, "DBL_DIG is too big");
+
+ bool full_precision_needed = true;
+ if (std::abs(value) <= kDoublePrecisionCheckMax) {
+ int snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG, value);
+
+ // The snprintf should never overflow because the buffer is significantly
+ // larger than the precision we asked for.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+
+ full_precision_needed = strtod(buffer, NULL) != value;
+ }
+
+ if (full_precision_needed) {
+ int snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", DBL_DIG + 2, value);
+
+ // Should never overflow; see above.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+ }
+ return buffer;
+}
+
+bool safe_strto64(const char* str, int64* value) {
+ if (!str) return false;
+
+ // Skip leading space.
+ while (isspace(*str)) ++str;
+
+ int64 vlimit = kint64max;
+ int sign = 1;
+ if (*str == '-') {
+ sign = -1;
+ ++str;
+ // Different limit for positive and negative integers.
+ vlimit = kint64min;
+ }
+
+ if (!isdigit(*str)) return false;
+
+ int64 result = 0;
+ if (sign == 1) {
+ do {
+ int digit = *str - '0';
+ if ((vlimit - digit) / 10 < result) {
+ return false;
+ }
+ result = result * 10 + digit;
+ ++str;
+ } while (isdigit(*str));
+ } else {
+ do {
+ int digit = *str - '0';
+ if ((vlimit + digit) / 10 > result) {
+ return false;
+ }
+ result = result * 10 - digit;
+ ++str;
+ } while (isdigit(*str));
+ }
+
+ // Skip trailing space.
+ while (isspace(*str)) ++str;
+
+ if (*str) return false;
+
+ *value = result;
+ return true;
+}
+
+bool safe_strto32(const char* str, int32* value) {
+ if (!str) return false;
+
+ // Skip leading space.
+ while (isspace(*str)) ++str;
+
+ int64 vmax = kint32max;
+ int sign = 1;
+ if (*str == '-') {
+ sign = -1;
+ ++str;
+ // Different max for positive and negative integers.
+ ++vmax;
+ }
+
+ if (!isdigit(*str)) return false;
+
+ int64 result = 0;
+ do {
+ result = result * 10 + *str - '0';
+ if (result > vmax) {
+ return false;
+ }
+ ++str;
+ } while (isdigit(*str));
+
+ // Skip trailing space.
+ while (isspace(*str)) ++str;
+
+ if (*str) return false;
+
+ *value = result * sign;
+ return true;
+}
+
+bool safe_strtof(const char* str, float* value) {
+ char* endptr;
+ *value = strtof(str, &endptr);
+ while (isspace(*endptr)) ++endptr;
+ // Ignore range errors from strtod/strtof.
+ // The values it returns on underflow and
+ // overflow are the right fallback in a
+ // robust setting.
+ return *str != '\0' && *endptr == '\0';
+}
+
+char* FloatToBuffer(float value, char* buffer) {
+ // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
+ // platforms these days. Just in case some system exists where FLT_DIG
+ // is significantly larger -- and risks overflowing our buffer -- we have
+ // this assert.
+ static_assert(FLT_DIG < 10, "FLT_DIG is too big");
+
+ int snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG, value);
+
+ // The snprintf should never overflow because the buffer is significantly
+ // larger than the precision we asked for.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+
+ float parsed_value;
+ if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) {
+ snprintf_result =
+ snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 2, value);
+
+ // Should never overflow; see above.
+ DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
+ }
+ return buffer;
+}
+
+string FpToString(Fprint fp) {
+ char buf[17];
+ snprintf(buf, sizeof(buf), "%016llx", static_cast<uint64>(fp));
+ return string(buf);
+}
+
+bool StringToFp(const string& s, Fprint* fp) {
+ char junk;
+ uint64 result;
+ if (sscanf(s.c_str(), "%llx%c", &result, &junk) == 1) {
+ *fp = result;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+string HumanReadableNumBytes(int64 num_bytes) {
+ if (num_bytes == kint64min) {
+ // Special case for the value whose negation is not representable.
+ return "-8E";
+ }
+
+ const char* neg_str = (num_bytes < 0) ? "-" : "";
+ if (num_bytes < 0) {
+ num_bytes = -num_bytes;
+ }
+
+ // Special case for bytes.
+ if (num_bytes < 1024) {
+ // No fractions for bytes.
+ char buf[8]; // Longest possible string is '-XXXXB'
+ snprintf(buf, sizeof(buf), "%s%lldB", neg_str,
+ static_cast<int64>(num_bytes));
+ return string(buf);
+ }
+
+ static const char units[] = "KMGTPE"; // int64 only goes up to E.
+ const char* unit = units;
+ while (num_bytes >= static_cast<int64>(1024) * 1024) {
+ num_bytes /= 1024;
+ ++unit;
+ CHECK(unit < units + TF_ARRAYSIZE(units));
+ }
+
+ // We use IEC binary prefixes (KiB, MiB, GiB, ...).
+ char buf[16];
+ snprintf(buf, sizeof(buf), ((*unit == 'K') ? "%s%.1f%ciB" : "%s%.2f%ciB"),
+ neg_str, num_bytes / 1024.0, *unit);
+ return string(buf);
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
new file mode 100644
index 0000000000..a30a862279
--- /dev/null
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -0,0 +1,92 @@
+#ifndef TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+#define TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+
+#include <string>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace strings {
+
+// ----------------------------------------------------------------------
+// FastIntToBufferLeft()
+// These are intended for speed.
+//
+// All functions take the output buffer as an arg. The 32-bit variants use
+// at most 12 bytes and the 64-bit variants at most 22 bytes, including the
+// terminating NUL. They all return a pointer to the end of the written
+// string.
+// ----------------------------------------------------------------------
+
+// Previously documented minimums -- the buffers provided must be at least this
+// long, though these numbers are subject to change:
+// Int32, UInt32: 12 bytes
+// Int64, UInt64: 22 bytes
+// Use kFastToBufferSize rather than hardcoding constants.
+static const int kFastToBufferSize = 32;
+
+// ----------------------------------------------------------------------
+// FastInt32ToBufferLeft()
+// FastUInt32ToBufferLeft()
+// FastInt64ToBufferLeft()
+// FastUInt64ToBufferLeft()
+//
+// These functions convert their numeric argument to an ASCII
+// representation of the numeric value in base 10, with the
+// representation being left-aligned in the buffer. The caller is
+// responsible for ensuring that the buffer has enough space to hold
+// the output. The buffer should typically be at least kFastToBufferSize
+// bytes.
+//
+// Returns a pointer to the end of the string (i.e. the null character
+// terminating the string).
+// ----------------------------------------------------------------------
+
+char* FastInt32ToBufferLeft(int32 i, char* buffer); // at least 12 bytes
+char* FastUInt32ToBufferLeft(uint32 i, char* buffer); // at least 12 bytes
+char* FastInt64ToBufferLeft(int64 i, char* buffer); // at least 22 bytes
+char* FastUInt64ToBufferLeft(uint64 i, char* buffer); // at least 22 bytes
+
+// Required buffer size for DoubleToBuffer is kFastToBufferSize.
+// Required buffer size for FloatToBuffer is kFastToBufferSize.
+char* DoubleToBuffer(double i, char* buffer);
+char* FloatToBuffer(float i, char* buffer);
+
+// Convert a 64-bit fingerprint value to an ASCII representation.
+string FpToString(Fprint fp);
+
+// Attempt to parse a fingerprint in the form encoded by FpToString. If
+// successful, stores the fingerprint in *fp and returns true. Otherwise,
+// returns false.
+bool StringToFp(const string& s, Fprint* fp);
+
+// Converts strings to 32-bit integer values.
+// Leading and trailing spaces are allowed.
+// Returns false on overflow or invalid input.
+bool safe_strto32(const char* str, int32* value);
+
+// Converts strings to 64-bit integer values.
+// Leading and trailing spaces are allowed.
+// Returns false on overflow or invalid input.
+bool safe_strto64(const char* str, int64* value);
+
+// Convert strings to floating point values.
+// Leading and trailing spaces are allowed.
+// Values may be rounded on over- and underflow.
+bool safe_strtof(const char* str, float* value);
+
+// Converts from an int64 representing a number of bytes to a
+// human readable string representing the same number.
+// e.g. 12345678 -> "11.77MiB".
+string HumanReadableNumBytes(int64 num_bytes);
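+
+// Illustrative usage sketch (added for exposition, not part of the original
+// change; the function name is hypothetical): formatting an int64 with
+// FastInt64ToBufferLeft and parsing it back with safe_strto64.
+inline bool ExampleRoundTripInt64(int64 value) {
+  char buf[kFastToBufferSize];
+  FastInt64ToBufferLeft(value, buf);  // writes the digits plus a trailing NUL
+  int64 parsed = 0;
+  return safe_strto64(buf, &parsed) && parsed == value;
+}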
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_NUMBERS_H_
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
new file mode 100644
index 0000000000..b178e6af53
--- /dev/null
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include <string>
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace strings {
+
+// NOTE: most of the routines in numbers.h are tested indirectly through
+// strcat_test.cc in this directory.
+
+// Test FpToString and StringToFp round-trips for values of various sizes.
+TEST(FpToString, Ints) {
+ for (int s = 0; s < 64; s++) {
+ for (int delta = -1; delta <= 1; delta++) {
+ uint64 fp = (1ull << s) + delta;
+ string str = FpToString(fp);
+ uint64 fp2;
+ EXPECT_TRUE(StringToFp(str, &fp2));
+ EXPECT_EQ(fp, fp2);
+ }
+ }
+ Fprint dummy;
+ EXPECT_FALSE(StringToFp("", &dummy));
+ EXPECT_FALSE(StringToFp("xyz", &dummy));
+ EXPECT_FALSE(StringToFp("0000000000000000xyz", &dummy));
+}
+
+TEST(HumanReadableNumBytes, Bytes) {
+ EXPECT_EQ("0B", HumanReadableNumBytes(0));
+ EXPECT_EQ("4B", HumanReadableNumBytes(4));
+ EXPECT_EQ("1023B", HumanReadableNumBytes(1023));
+
+ EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1024));
+ EXPECT_EQ("1.0KiB", HumanReadableNumBytes(1025));
+ EXPECT_EQ("1.5KiB", HumanReadableNumBytes(1500));
+ EXPECT_EQ("1.9KiB", HumanReadableNumBytes(1927));
+
+ EXPECT_EQ("2.0KiB", HumanReadableNumBytes(2048));
+ EXPECT_EQ("1.00MiB", HumanReadableNumBytes(1 << 20));
+ EXPECT_EQ("11.77MiB", HumanReadableNumBytes(12345678));
+ EXPECT_EQ("1.00GiB", HumanReadableNumBytes(1 << 30));
+
+ EXPECT_EQ("1.00TiB", HumanReadableNumBytes(1LL << 40));
+ EXPECT_EQ("1.00PiB", HumanReadableNumBytes(1LL << 50));
+ EXPECT_EQ("1.00EiB", HumanReadableNumBytes(1LL << 60));
+
+ // Try a few negative numbers
+ EXPECT_EQ("-1B", HumanReadableNumBytes(-1));
+ EXPECT_EQ("-4B", HumanReadableNumBytes(-4));
+ EXPECT_EQ("-1000B", HumanReadableNumBytes(-1000));
+ EXPECT_EQ("-11.77MiB", HumanReadableNumBytes(-12345678));
+ EXPECT_EQ("-8E", HumanReadableNumBytes(kint64min));
+}
+
+TEST(safe_strto32, Int32s) {
+ int32 result;
+
+ EXPECT_EQ(true, safe_strto32("1", &result));
+ EXPECT_EQ(1, result);
+ EXPECT_EQ(true, safe_strto32("123", &result));
+ EXPECT_EQ(123, result);
+ EXPECT_EQ(true, safe_strto32(" -123 ", &result));
+ EXPECT_EQ(-123, result);
+ EXPECT_EQ(true, safe_strto32("2147483647", &result));
+ EXPECT_EQ(2147483647, result);
+ EXPECT_EQ(true, safe_strto32("-2147483648", &result));
+ EXPECT_EQ(-2147483648, result);
+
+ // Invalid argument
+ EXPECT_EQ(false, safe_strto32(" 132as ", &result));
+ EXPECT_EQ(false, safe_strto32(" 132.2 ", &result));
+ EXPECT_EQ(false, safe_strto32(" -", &result));
+ EXPECT_EQ(false, safe_strto32("", &result));
+ EXPECT_EQ(false, safe_strto32(" ", &result));
+ EXPECT_EQ(false, safe_strto32("123 a", &result));
+
+ // Overflow
+ EXPECT_EQ(false, safe_strto32("2147483648", &result));
+ EXPECT_EQ(false, safe_strto32("-2147483649", &result));
+}
+
+TEST(safe_strto64, Int64s) {
+ int64 result;
+
+ EXPECT_EQ(true, safe_strto64("1", &result));
+ EXPECT_EQ(1, result);
+ EXPECT_EQ(true, safe_strto64("123", &result));
+ EXPECT_EQ(123, result);
+ EXPECT_EQ(true, safe_strto64(" -123 ", &result));
+ EXPECT_EQ(-123, result);
+ EXPECT_EQ(true, safe_strto64("9223372036854775807", &result));
+ EXPECT_EQ(9223372036854775807, result);
+ EXPECT_EQ(true, safe_strto64("-9223372036854775808", &result));
+ // kint64min == -9223372036854775808
+ // Using -9223372036854775808 directly results in an out-of-range error.
+ EXPECT_EQ(kint64min, result);
+
+ // Invalid argument
+ EXPECT_EQ(false, safe_strto64(" 132as ", &result));
+ EXPECT_EQ(false, safe_strto64(" 132.2 ", &result));
+ EXPECT_EQ(false, safe_strto64(" -", &result));
+ EXPECT_EQ(false, safe_strto64("", &result));
+ EXPECT_EQ(false, safe_strto64(" ", &result));
+ EXPECT_EQ(false, safe_strto64("123 a", &result));
+
+ // Overflow
+ EXPECT_EQ(false, safe_strto64("9223372036854775808", &result));
+ EXPECT_EQ(false, safe_strto64("-9223372036854775809", &result));
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/ordered_code.cc b/tensorflow/core/lib/strings/ordered_code.cc
new file mode 100644
index 0000000000..ec67595ebb
--- /dev/null
+++ b/tensorflow/core/lib/strings/ordered_code.cc
@@ -0,0 +1,515 @@
+#include "tensorflow/core/lib/strings/ordered_code.h"
+
+#include <assert.h>
+#include <stddef.h>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace strings {
+
+// We encode a string in different ways depending on whether the item
+// should be in lexicographically increasing or decreasing order.
+//
+//
+// Lexicographically increasing order
+//
+// We want a string-to-string mapping F(x) such that for any two strings
+//
+// x < y => F(x) < F(y)
+//
+// In addition to the normal characters '\x00' through '\xff', we want to
+// encode a few extra symbols in strings:
+//
+// <sep> Separator between items
+// <infinity> Infinite string
+//
+// Therefore we need an alphabet with at least 258 symbols. Each
+// character '\1' through '\xfe' is mapped to itself. The other four are
+// encoded into two-letter sequences starting with '\0' and '\xff':
+//
+// <sep> encoded as => \0\1
+// \0 encoded as => \0\xff
+// \xff encoded as => \xff\x00
+// <infinity> encoded as => \xff\xff
+//
+// The remaining two-letter sequences starting with '\0' and '\xff' are
+// currently unused.
+//
+// F(<infinity>) is defined above.  For any finite string x, F(x) is the
+// concatenation of the encodings of x's characters followed by the encoding
+// for <sep>.  The
+// ordering of two finite strings is the same as the ordering of the
+// respective characters at the first position where they differ, which in
+// turn is the same as the ordering of the encodings of those two
+// characters. Moreover, for every finite string x, F(x) < F(<infinity>).
+//
+//
+// Lexicographically decreasing order
+//
+// We want a string-to-string mapping G(x) such that for any two strings,
+// whether finite or not,
+//
+// x < y => G(x) > G(y)
+//
+// To achieve this, define G(x) to be the inversion of F(x): I(F(x)). In
+// other words, invert every bit in F(x) to get G(x). For example,
+//
+// x = \x00\x13\xff
+// F(x) = \x00\xff\x13\xff\x00\x00\x01 escape \0, \xff, append F(<sep>)
+// G(x) = \xff\x00\xec\x00\xff\xff\xfe invert every bit in F(x)
+//
+// x = <infinity>
+// F(x) = \xff\xff
+// G(x) = \x00\x00
+//
+// Another example is
+//
+// x F(x) G(x) = I(F(x))
+// - ---- --------------
+// <infinity> \xff\xff \x00\x00
+// "foo" foo\0\1 \x99\x90\x90\xff\xfe
+// "aaa" aaa\0\1 \x9e\x9e\x9e\xff\xfe
+// "aa" aa\0\1 \x9e\x9e\xff\xfe
+// "" \0\1 \xff\xfe
+//
+// More generally and rigorously, if for any two strings x and y
+//
+// F(x) < F(y) => I(F(x)) > I(F(y)) (1)
+//
+// it would follow that x < y => G(x) > G(y) because
+//
+// x < y => F(x) < F(y) => G(x) = I(F(x)) > I(F(y)) = G(y)
+//
+// We now show why (1) is true, in two parts. Notice that for any two
+// strings x < y, F(x) is *not* a proper prefix of F(y). Suppose x is a
+// proper prefix of y (say, x="abc" < y="abcd"). F(x) and F(y) diverge at
+// the F(<sep>) in F(x) (v. F('d') in the example). Suppose x is not a
+// proper prefix of y (say, x="abce" < y="abd"), F(x) and F(y) diverge at
+// their respective encodings of the characters where x and y diverge
+// (F('c') v. F('d')). Finally, if y=<infinity>, we can see that
+// F(y)=\xff\xff is not the prefix of F(x) for any finite string x, simply
+// by considering all the possible first characters of F(x).
+//
+// Given that F(x) is not a proper prefix of F(y), the order of F(x) and F(y)
+// is determined by the byte where F(x) and F(y) diverge. For example, the
+// order of F(x)="eefh" and F(y)="eeg" is determined by their third
+// characters. I(p) inverts each byte in p, which effectively subtracts
+// each byte from 0xff. So, in this example, I('f') > I('g'), and thus
+// I(F(x)) > I(F(y)).
+//
+//
+// Implementation
+//
+// To implement G(x) efficiently, we use C++ templates to instantiate two
+// versions of the code to produce F(x), one for normal encoding (giving us
+// F(x)) and one for inverted encoding (giving us G(x) = I(F(x))).
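+//
+// Worked example of F(x) (illustrative only, following the escape table
+// above): WriteString() applied to the 3-byte input "a\0b" produces
+//
+//   'a'    =>  a              ordinary byte, copied verbatim
+//   \0     =>  \0\xff         (kEscape1 kNullCharacter)
+//   'b'    =>  b
+//   <sep>  =>  \0\1           (kEscape1 kSeparator terminator)
+//
+// i.e. the six bytes 0x61 0x00 0xff 0x62 0x00 0x01.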
+
+static const char kEscape1 = '\000';
+static const char kNullCharacter = '\xff'; // Combined with kEscape1
+static const char kSeparator = '\001'; // Combined with kEscape1
+
+static const char kEscape2 = '\xff';
+static const char kInfinity = '\xff'; // Combined with kEscape2
+static const char kFFCharacter = '\000'; // Combined with kEscape2
+
+static const char kEscape1_Separator[2] = {kEscape1, kSeparator};
+
+// Append to "*dest" the "len" bytes starting from "*src".
+inline static void AppendBytes(string* dest, const char* src, int len) {
+ dest->append(src, len);
+}
+
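+// Returns true iff "c" is one of the two escape bytes of the encoding:
+// '\x00' (kEscape1) or '\xff' (kEscape2).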
+inline bool IsSpecialByte(char c) { return ((unsigned char)(c + 1)) < 2; }
+
+// Return a pointer to the first byte in the range "[start..limit)"
+// whose value is 0 or 255 (kEscape1 or kEscape2). If no such byte
+// exists in the range, returns "limit".
+inline const char* SkipToNextSpecialByte(const char* start, const char* limit) {
+ // If these constants were ever changed, this routine needs to change
+ DCHECK_EQ(kEscape1, 0);
+ DCHECK_EQ(kEscape2 & 0xffu, 255u);
+ const char* p = start;
+ while (p < limit && !IsSpecialByte(*p)) {
+ p++;
+ }
+ return p;
+}
+
+// Expose SkipToNextSpecialByte for testing purposes
+const char* OrderedCode::TEST_SkipToNextSpecialByte(const char* start,
+ const char* limit) {
+ return SkipToNextSpecialByte(start, limit);
+}
+
+// Helper routine to encode "s" and append to "*dest", escaping special
+// characters.
+inline static void EncodeStringFragment(string* dest, StringPiece s) {
+ const char* p = s.data();
+ const char* limit = p + s.size();
+ const char* copy_start = p;
+ while (true) {
+ p = SkipToNextSpecialByte(p, limit);
+ if (p >= limit) break; // No more special characters that need escaping
+ char c = *(p++);
+ DCHECK(IsSpecialByte(c));
+ if (c == kEscape1) {
+ AppendBytes(dest, copy_start, p - copy_start - 1);
+ dest->push_back(kEscape1);
+ dest->push_back(kNullCharacter);
+ copy_start = p;
+ } else {
+ assert(c == kEscape2);
+ AppendBytes(dest, copy_start, p - copy_start - 1);
+ dest->push_back(kEscape2);
+ dest->push_back(kFFCharacter);
+ copy_start = p;
+ }
+ }
+ if (p > copy_start) {
+ AppendBytes(dest, copy_start, p - copy_start);
+ }
+}
+
+void OrderedCode::WriteString(string* dest, StringPiece s) {
+ EncodeStringFragment(dest, s);
+ AppendBytes(dest, kEscape1_Separator, 2);
+}
+
+void OrderedCode::WriteNumIncreasing(string* dest, uint64 val) {
+ // Values are encoded with a single byte length prefix, followed
+ // by the actual value in big-endian format with leading 0 bytes
+ // dropped.
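+  // For example, under this format 0 encodes as "\x00", 0x2a encodes as
+  // "\x01\x2a", and 0x1234 encodes as "\x02\x12\x34".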
+ unsigned char buf[9]; // 8 bytes for value plus one byte for length
+ int len = 0;
+ while (val > 0) {
+ len++;
+ buf[9 - len] = (val & 0xff);
+ val >>= 8;
+ }
+ buf[9 - len - 1] = (unsigned char)len;
+ len++;
+ AppendBytes(dest, reinterpret_cast<const char*>(buf + 9 - len), len);
+}
+
+// Parse the encoding of a previously encoded string.
+// If parse succeeds, return true, consume encoding from
+// "*src", and if result != NULL append the decoded string to "*result".
+// Otherwise, return false and leave both undefined.
+inline static bool ReadStringInternal(StringPiece* src, string* result) {
+ const char* start = src->data();
+ const char* string_limit = src->data() + src->size();
+
+  // We only scan up to "string_limit - 2" since a valid string must end
+  // with a two-character terminator: 'kEscape1 kSeparator'
+ const char* limit = string_limit - 1;
+ const char* copy_start = start;
+ while (true) {
+ start = SkipToNextSpecialByte(start, limit);
+ if (start >= limit) break; // No terminator sequence found
+ const char c = *(start++);
+ // If inversion is required, instead of inverting 'c', we invert the
+ // character constants to which 'c' is compared. We get the same
+ // behavior but save the runtime cost of inverting 'c'.
+ DCHECK(IsSpecialByte(c));
+ if (c == kEscape1) {
+ if (result) {
+ AppendBytes(result, copy_start, start - copy_start - 1);
+ }
+ // kEscape1 kSeparator ends component
+ // kEscape1 kNullCharacter represents '\0'
+ const char next = *(start++);
+ if (next == kSeparator) {
+ src->remove_prefix(start - src->data());
+ return true;
+ } else if (next == kNullCharacter) {
+ if (result) {
+ *result += '\0';
+ }
+ } else {
+ return false;
+ }
+ copy_start = start;
+ } else {
+ assert(c == kEscape2);
+ if (result) {
+ AppendBytes(result, copy_start, start - copy_start - 1);
+ }
+ // kEscape2 kFFCharacter represents '\xff'
+ // kEscape2 kInfinity is an error
+ const char next = *(start++);
+ if (next == kFFCharacter) {
+ if (result) {
+ *result += '\xff';
+ }
+ } else {
+ return false;
+ }
+ copy_start = start;
+ }
+ }
+ return false;
+}
+
+bool OrderedCode::ReadString(StringPiece* src, string* result) {
+ return ReadStringInternal(src, result);
+}
+
+bool OrderedCode::ReadNumIncreasing(StringPiece* src, uint64* result) {
+ if (src->empty()) {
+ return false; // Not enough bytes
+ }
+
+ // Decode length byte
+ const size_t len = static_cast<unsigned char>((*src)[0]);
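+  // For example, an input starting with "\x02\x12\x34" has len == 2 and
+  // decodes to 0x1234, consuming three bytes in total.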
+
+ // If len > 0 and src is longer than 1, the first byte of "payload"
+ // must be non-zero (otherwise the encoding is not minimal).
+ // In opt mode, we don't enforce that encodings must be minimal.
+ DCHECK(0 == len || src->size() == 1 || (*src)[1] != '\0')
+ << "invalid encoding";
+
+ if (len + 1 > src->size() || len > 8) {
+ return false; // Not enough bytes or too many bytes
+ }
+
+ if (result) {
+ uint64 tmp = 0;
+ for (size_t i = 0; i < len; i++) {
+ tmp <<= 8;
+ tmp |= static_cast<unsigned char>((*src)[1 + i]);
+ }
+ *result = tmp;
+ }
+ src->remove_prefix(len + 1);
+ return true;
+}
+
+void OrderedCode::TEST_Corrupt(string* str, int k) {
+ int seen_seps = 0;
+ for (size_t i = 0; i + 1 < str->size(); i++) {
+ if ((*str)[i] == kEscape1 && (*str)[i + 1] == kSeparator) {
+ seen_seps++;
+ if (seen_seps == k) {
+ (*str)[i + 1] = kSeparator + 1;
+ return;
+ }
+ }
+ }
+}
+
+// Signed number encoding/decoding /////////////////////////////////////
+//
+// The format is as follows:
+//
+// The first bit (the most significant bit of the first byte)
+// represents the sign, 0 if the number is negative and
+// 1 if the number is >= 0.
+//
+// The unbroken run of bits following the sign bit that have the same value
+// as the sign bit, up to 9 of them (the 8th and 9th being the most
+// significant bits of the next byte), are size bits that count the number
+// of bytes after the first byte.
+// That is, the total length is between 1 and 10 bytes.
+//
+// The value occupies the bits after the sign bit and the "size bits"
+// till the end of the string, in network byte order. If the number
+// is negative, the bits are in two's complement.
+//
+//
+// Example 1: number 0x424242 -> 4 byte big-endian hex string 0xf0424242:
+//
+// +---------------+---------------+---------------+---------------+
+// 1 1 1 1 0 0 0 0 0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 1 0
+// +---------------+---------------+---------------+---------------+
+// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// | | | | payload: the remaining bits after the sign and size bits
+// | | | | and the delimiter bit, the value is 0x424242
+// | | | |
+// | size bits: 3 successive bits with the same value as the sign bit
+// | (followed by a delimiter bit with the opposite value)
+// | mean that there are 3 bytes after the first byte, 4 total
+// |
+// sign bit: 1 means that the number is non-negative
+//
+// Example 2: negative number -0x800 -> 2 byte big-endian hex string 0x3800:
+//
+// +---------------+---------------+
+// 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0
+// +---------------+---------------+
+// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+// | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// | | payload: the remaining bits after the sign and size bits and the
+//  |  |           delimiter bit, in two's complement because of the negative sign,
+// | | value is ~0x7ff, represents the value -0x800
+// | |
+// | size bits: 1 bit with the same value as the sign bit
+// | (followed by a delimiter bit with the opposite value)
+// | means that there is 1 byte after the first byte, 2 total
+// |
+// sign bit: 0 means that the number is negative
+//
+//
+// Compared with the simpler unsigned format used for uint64 numbers,
+// this format is more compact for small numbers, namely one byte encodes
+// numbers in the range [-64,64), two bytes cover the range [-2^13,2^13), etc.
+// In general, n bytes encode numbers in the range [-2^(n*7-1),2^(n*7-1)).
+// (The cross-over point for compactness of representation is 8 bytes,
+// where this format only covers the range [-2^55,2^55),
+// whereas an encoding with sign bit and length in the first byte and
+// payload in all following bytes would cover [-2^56,2^56).)
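+//
+// For example (these match the expected encodings in ordered_code_test.cc):
+//
+//      0 -> "\x80"        1 -> "\x81"       63 -> "\xbf"
+//     64 -> "\xc0\x40"   -1 -> "\x7f"      -64 -> "\x40"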
+
+static const int kMaxSigned64Length = 10;
+
+// This array maps encoding length to header bits in the first two bytes.
+static const char kLengthToHeaderBits[1 + kMaxSigned64Length][2] = {
+ {0, 0}, {'\x80', 0}, {'\xc0', 0}, {'\xe0', 0},
+ {'\xf0', 0}, {'\xf8', 0}, {'\xfc', 0}, {'\xfe', 0},
+ {'\xff', 0}, {'\xff', '\x80'}, {'\xff', '\xc0'}};
+
+// This array maps encoding lengths to the header bits that overlap with
+// the payload and need fixing when reading.
+static const uint64 kLengthToMask[1 + kMaxSigned64Length] = {
+ 0ULL,
+ 0x80ULL,
+ 0xc000ULL,
+ 0xe00000ULL,
+ 0xf0000000ULL,
+ 0xf800000000ULL,
+ 0xfc0000000000ULL,
+ 0xfe000000000000ULL,
+ 0xff00000000000000ULL,
+ 0x8000000000000000ULL,
+ 0ULL};
+
+// This array maps the number of bits in a number to the encoding
+// length produced by WriteSignedNumIncreasing.
+// For positive numbers, the number of bits is 1 plus the most significant
+// bit position (the highest bit position in a positive int64 is 63).
+// For a negative number n, we count the bits in ~n.
+// That is, length = kBitsToLength[Bits::Log2Floor64(n < 0 ? ~n : n) + 1].
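+// For example, 63 has 6 significant bits, so kBitsToLength[6] == 1 (a
+// one-byte encoding), while 64 has 7 significant bits and kBitsToLength[7] == 2.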
+static const int8 kBitsToLength[1 + 63] = {
+ 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4,
+ 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7,
+ 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 10};
+
+#if defined(__GNUC__)
+// Returns floor(lg(n)). Returns -1 if n == 0.
+static int Log2Floor64(uint64 n) {
+ return n == 0 ? -1 : 63 ^ __builtin_clzll(n);
+}
+#else
+// Portable slow version
+static int Log2Floor32_Portable(uint32 n) {
+ if (n == 0) return -1;
+ int log = 0;
+ uint32 value = n;
+ for (int i = 4; i >= 0; --i) {
+ int shift = (1 << i);
+ uint32 x = value >> shift;
+ if (x != 0) {
+ value = x;
+ log += shift;
+ }
+ }
+ assert(value == 1);
+ return log;
+}
+// Returns floor(lg(n)). Returns -1 if n == 0.
+static int Log2Floor64(uint64 n) {
+ const uint32 topbits = static_cast<uint32>(n >> 32);
+ if (topbits == 0) {
+ // Top bits are zero, so scan in bottom bits
+ return Log2Floor32_Portable(static_cast<uint32>(n));
+ } else {
+ return 32 + Log2Floor32_Portable(topbits);
+ }
+}
+#endif
+
+// Calculates the encoding length in bytes of the signed number n.
+static inline int SignedEncodingLength(int64 n) {
+ return kBitsToLength[Log2Floor64(n < 0 ? ~n : n) + 1];
+}
+
+static void StoreBigEndian64(char* dst, uint64 v) {
+ for (int i = 0; i < 8; i++) {
+ dst[i] = (v >> (56 - 8 * i)) & 0xff;
+ }
+}
+
+static uint64 LoadBigEndian64(const char* src) {
+ uint64 result = 0;
+ for (int i = 0; i < 8; i++) {
+ unsigned char c = static_cast<unsigned char>(src[i]);
+ result |= static_cast<uint64>(c) << (56 - 8 * i);
+ }
+ return result;
+}
+
+void OrderedCode::WriteSignedNumIncreasing(string* dest, int64 val) {
+ const uint64 x = val < 0 ? ~val : val;
+ if (x < 64) { // fast path for encoding length == 1
+ *dest += kLengthToHeaderBits[1][0] ^ val;
+ return;
+ }
+ // buf = val in network byte order, sign extended to 10 bytes
+ const char sign_byte = val < 0 ? '\xff' : '\0';
+ char buf[10] = {
+ sign_byte, sign_byte,
+ };
+ StoreBigEndian64(buf + 2, val);
+ static_assert(sizeof(buf) == kMaxSigned64Length, "max length size mismatch");
+ const int len = SignedEncodingLength(x);
+ DCHECK_GE(len, 2);
+ char* const begin = buf + sizeof(buf) - len;
+ begin[0] ^= kLengthToHeaderBits[len][0];
+ begin[1] ^= kLengthToHeaderBits[len][1]; // ok because len >= 2
+ dest->append(begin, len);
+}
+
+bool OrderedCode::ReadSignedNumIncreasing(StringPiece* src, int64* result) {
+ if (src->empty()) return false;
+ const uint64 xor_mask = (!((*src)[0] & 0x80)) ? ~0ULL : 0ULL;
+ const unsigned char first_byte = (*src)[0] ^ (xor_mask & 0xff);
+
+ // now calculate and test length, and set x to raw (unmasked) result
+ int len;
+ uint64 x;
+ if (first_byte != 0xff) {
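+    // After the xor_mask normalization above, first_byte consists of "len"
+    // leading 1-bits followed by a 0 delimiter bit (len <= 7 in this branch),
+    // so flipping the byte and taking Log2Floor64 recovers the length.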
+ len = 7 - Log2Floor64(first_byte ^ 0xff);
+ if (src->size() < static_cast<size_t>(len)) return false;
+ x = xor_mask; // sign extend using xor_mask
+ for (int i = 0; i < len; ++i)
+ x = (x << 8) | static_cast<unsigned char>((*src)[i]);
+ } else {
+ len = 8;
+ if (src->size() < static_cast<size_t>(len)) return false;
+ const unsigned char second_byte = (*src)[1] ^ (xor_mask & 0xff);
+ if (second_byte >= 0x80) {
+ if (second_byte < 0xc0) {
+ len = 9;
+ } else {
+ const unsigned char third_byte = (*src)[2] ^ (xor_mask & 0xff);
+ if (second_byte == 0xc0 && third_byte < 0x80) {
+ len = 10;
+ } else {
+ return false; // either len > 10 or len == 10 and #bits > 63
+ }
+ }
+ if (src->size() < static_cast<size_t>(len)) return false;
+ }
+ x = LoadBigEndian64(src->data() + len - 8);
+ }
+
+ x ^= kLengthToMask[len]; // remove spurious header bits
+
+ DCHECK_EQ(len, SignedEncodingLength(x)) << "invalid encoding";
+
+ if (result) *result = x;
+ src->remove_prefix(len);
+ return true;
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/ordered_code.h b/tensorflow/core/lib/strings/ordered_code.h
new file mode 100644
index 0000000000..39f1df9a94
--- /dev/null
+++ b/tensorflow/core/lib/strings/ordered_code.h
@@ -0,0 +1,77 @@
+// This module provides routines for encoding a sequence of typed
+// entities into a string. The resulting strings can be
+// lexicographically compared to yield the same comparison value that
+// would have been generated if the encoded items had been compared
+// one by one according to their type.
+//
+// More precisely, suppose:
+// 1. string A is generated by encoding the sequence of items [A_1..A_n]
+// 2. string B is generated by encoding the sequence of items [B_1..B_n]
+// 3. The types match; i.e., for all i: A_i was encoded using
+// the same routine as B_i
+// Then:
+// Comparing A vs. B lexicographically is the same as comparing
+// the vectors [A_1..A_n] and [B_1..B_n] lexicographically.
+//
+// Furthermore, if n < m, the encoding of [A_1..A_n] is a strict prefix of
+// [A_1..A_m] (unless m = n+1 and A_m is the empty string encoded with
+// WriteTrailingString, in which case the encodings are equal).
+//
+// This module is often useful when generating multi-part sstable
+// keys that have to be ordered in a particular fashion.
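+//
+// A minimal usage sketch (illustrative only; mirrors ordered_code_test.cc):
+//
+//   string key;
+//   OrderedCode::WriteString(&key, "user");      // appends "user\0\1"
+//   OrderedCode::WriteNumIncreasing(&key, 42);   // appends "\x01\x2a"
+//
+//   StringPiece in(key);
+//   string name;
+//   uint64 id;
+//   OrderedCode::ReadString(&in, &name);         // name == "user"
+//   OrderedCode::ReadNumIncreasing(&in, &id);    // id == 42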
+
+#ifndef TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__
+#define TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__
+
+#include <string>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+class StringPiece;
+
+namespace strings {
+
+class OrderedCode {
+ public:
+ // -------------------------------------------------------------------
+  // Encoding routines: each one of the following routines appends
+ // one item to "*dest" in an encoding where larger values are
+ // ordered lexicographically after smaller values.
+ static void WriteString(string* dest, StringPiece str);
+ static void WriteNumIncreasing(string* dest, uint64 num);
+ static void WriteSignedNumIncreasing(string* dest, int64 num);
+
+ // -------------------------------------------------------------------
+ // Decoding routines: these extract an item earlier encoded using
+ // the corresponding WriteXXX() routines above. The item is read
+ // from "*src"; "*src" is modified to point past the decoded item;
+ // and if "result" is non-NULL, "*result" is modified to contain the
+  // result.  In the case of a string result, the decoded string is appended to
+ // "*result". Returns true if the next item was read successfully, false
+ // otherwise.
+ static bool ReadString(StringPiece* src, string* result);
+ static bool ReadNumIncreasing(StringPiece* src, uint64* result);
+ static bool ReadSignedNumIncreasing(StringPiece* src, int64* result);
+
+ // Helper for testing: corrupt "*str" by changing the kth item separator
+ // in the string.
+ static void TEST_Corrupt(string* str, int k);
+
+ // Helper for testing.
+ // SkipToNextSpecialByte is an internal routine defined in the .cc file
+ // with the following semantics. Return a pointer to the first byte
+ // in the range "[start..limit)" whose value is 0 or 255. If no such
+ // byte exists in the range, returns "limit".
+ static const char* TEST_SkipToNextSpecialByte(const char* start,
+ const char* limit);
+
+ private:
+ // This has only static methods, so disallow construction entirely
+ OrderedCode();
+ TF_DISALLOW_COPY_AND_ASSIGN(OrderedCode);
+};
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_ORDERED_CODE_H__
diff --git a/tensorflow/core/lib/strings/ordered_code_test.cc b/tensorflow/core/lib/strings/ordered_code_test.cc
new file mode 100644
index 0000000000..d517d14f4a
--- /dev/null
+++ b/tensorflow/core/lib/strings/ordered_code_test.cc
@@ -0,0 +1,1183 @@
+#include "tensorflow/core/lib/strings/ordered_code.h"
+
+#include <float.h>
+#include <stddef.h>
+#include <limits>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace strings {
+
+static string RandomString(random::SimplePhilox* rnd, int len) {
+ string x;
+ for (int i = 0; i < len; i++) {
+ x += rnd->Uniform(256);
+ }
+ return x;
+}
+
+// ---------------------------------------------------------------------
+// Utility template functions (they help templatize the tests below)
+
+// Read/WriteIncreasing are defined for string, uint64, int64 below.
+template <typename T>
+static void OCWriteIncreasing(string* dest, const T& val);
+template <typename T>
+static bool OCReadIncreasing(StringPiece* src, T* result);
+
+// Read/WriteIncreasing<string>
+template <>
+void OCWriteIncreasing<string>(string* dest, const string& val) {
+ OrderedCode::WriteString(dest, val);
+}
+template <>
+bool OCReadIncreasing<string>(StringPiece* src, string* result) {
+ return OrderedCode::ReadString(src, result);
+}
+
+// Read/WriteIncreasing<uint64>
+template <>
+void OCWriteIncreasing<uint64>(string* dest, const uint64& val) {
+ OrderedCode::WriteNumIncreasing(dest, val);
+}
+template <>
+bool OCReadIncreasing<uint64>(StringPiece* src, uint64* result) {
+ return OrderedCode::ReadNumIncreasing(src, result);
+}
+
+// Read/WriteIncreasing<int64>
+template <>
+void OCWriteIncreasing<int64>(string* dest, const int64& val) {
+ OrderedCode::WriteSignedNumIncreasing(dest, val);
+}
+template <>
+bool OCReadIncreasing<int64>(StringPiece* src, int64* result) {
+ return OrderedCode::ReadSignedNumIncreasing(src, result);
+}
+
+template <typename T>
+string OCWrite(T val) {
+ string result;
+ OCWriteIncreasing<T>(&result, val);
+ return result;
+}
+
+template <typename T>
+void OCWriteToString(string* result, T val) {
+ OCWriteIncreasing<T>(result, val);
+}
+
+template <typename T>
+bool OCRead(StringPiece* s, T* val) {
+ return OCReadIncreasing<T>(s, val);
+}
+
+// ---------------------------------------------------------------------
+// Numbers
+
+template <typename T>
+static T TestRead(const string& a) {
+ // gracefully reject any proper prefix of an encoding
+ for (int i = 0; i < a.size() - 1; ++i) {
+ StringPiece s(a.data(), i);
+ CHECK(!OCRead<T>(&s, NULL));
+ CHECK_EQ(s, a.substr(0, i));
+ }
+
+ StringPiece s(a);
+ T v;
+ CHECK(OCRead<T>(&s, &v));
+ CHECK(s.empty());
+ return v;
+}
+
+template <typename T>
+static void TestWriteRead(T expected) {
+ EXPECT_EQ(expected, TestRead<T>(OCWrite<T>(expected)));
+}
+
+// Verifies that the second Write* call appends a non-empty string to its
+// output.
+template <typename T, typename U>
+static void TestWriteAppends(T first, U second) {
+ string encoded;
+ OCWriteToString<T>(&encoded, first);
+ string encoded_first_only = encoded;
+ OCWriteToString<U>(&encoded, second);
+ EXPECT_NE(encoded, encoded_first_only);
+ EXPECT_TRUE(StringPiece(encoded).starts_with(encoded_first_only));
+}
+
+template <typename T>
+static void TestNumbers(T multiplier) {
+ // first test powers of 2 (and nearby numbers)
+ for (T x = std::numeric_limits<T>().max(); x != 0; x /= 2) {
+ TestWriteRead(multiplier * (x - 1));
+ TestWriteRead(multiplier * x);
+ if (x != std::numeric_limits<T>::max()) {
+ TestWriteRead(multiplier * (x + 1));
+ } else if (multiplier < 0 && multiplier == -1) {
+ TestWriteRead(-x - 1);
+ }
+ }
+
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ for (int bits = 1; bits <= std::numeric_limits<T>().digits; ++bits) {
+ // test random non-negative numbers with given number of significant bits
+ const uint64 mask = (~0ULL) >> (64 - bits);
+ for (int i = 0; i < 1000; i++) {
+ T x = rnd.Rand64() & mask;
+ TestWriteRead(multiplier * x);
+ T y = rnd.Rand64() & mask;
+ TestWriteAppends(multiplier * x, multiplier * y);
+ }
+ }
+}
+
+// Return true iff 'a' is "before" 'b'
+static bool CompareStrings(const string& a, const string& b) { return (a < b); }
+
+template <typename T>
+static void TestNumberOrdering() {
+ // first the negative numbers (if T is signed, otherwise no-op)
+ string laststr = OCWrite<T>(std::numeric_limits<T>().min());
+ for (T num = std::numeric_limits<T>().min() / 2; num != 0; num /= 2) {
+ string strminus1 = OCWrite<T>(num - 1);
+ string str = OCWrite<T>(num);
+ string strplus1 = OCWrite<T>(num + 1);
+
+ CHECK(CompareStrings(strminus1, str));
+ CHECK(CompareStrings(str, strplus1));
+
+ // Compare 'str' with 'laststr'. When we approach 0, 'laststr' is
+ // not necessarily before 'strminus1'.
+ CHECK(CompareStrings(laststr, str));
+ laststr = str;
+ }
+
+ // then the positive numbers
+ laststr = OCWrite<T>(0);
+ T num = 1;
+ while (num < std::numeric_limits<T>().max() / 2) {
+ num *= 2;
+ string strminus1 = OCWrite<T>(num - 1);
+ string str = OCWrite<T>(num);
+ string strplus1 = OCWrite<T>(num + 1);
+
+ CHECK(CompareStrings(strminus1, str));
+ CHECK(CompareStrings(str, strplus1));
+
+ // Compare 'str' with 'laststr'.
+ CHECK(CompareStrings(laststr, str));
+ laststr = str;
+ }
+}
+
+// Helper routine for testing TEST_SkipToNextSpecialByte
+static int FindSpecial(const string& x) {
+ const char* p = x.data();
+ const char* limit = p + x.size();
+ const char* result = OrderedCode::TEST_SkipToNextSpecialByte(p, limit);
+ return result - p;
+}
+
+TEST(OrderedCode, SkipToNextSpecialByte) {
+ for (size_t len = 0; len < 256; len++) {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string x;
+ while (x.size() < len) {
+ char c = 1 + rnd.Uniform(254);
+ ASSERT_NE(c, 0);
+ ASSERT_NE(c, 255);
+ x += c; // No 0 bytes, no 255 bytes
+ }
+ EXPECT_EQ(FindSpecial(x), x.size());
+ for (size_t special_pos = 0; special_pos < len; special_pos++) {
+ for (size_t special_test = 0; special_test < 2; special_test++) {
+ const char special_byte = (special_test == 0) ? 0 : 255;
+ string y = x;
+ y[special_pos] = special_byte;
+ EXPECT_EQ(FindSpecial(y), special_pos);
+ if (special_pos < 16) {
+ // Add some special bytes after the one at special_pos to make sure
+ // we still return the earliest special byte in the string
+ for (size_t rest = special_pos + 1; rest < len; rest++) {
+ if (rnd.OneIn(3)) {
+ y[rest] = rnd.OneIn(2) ? 0 : 255;
+ EXPECT_EQ(FindSpecial(y), special_pos);
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+TEST(OrderedCode, ExhaustiveFindSpecial) {
+ char buf[16];
+ char* limit = buf + sizeof(buf);
+ int count = 0;
+ for (int start_offset = 0; start_offset <= 5; start_offset += 5) {
+ // We test exhaustively with all combinations of 3 bytes starting
+ // at offset 0 and offset 5 (so as to test with the bytes at both
+ // ends of a 64-bit word).
+ for (size_t i = 0; i < sizeof(buf); i++) {
+ buf[i] = 'a'; // Not a special byte
+ }
+ for (int b0 = 0; b0 < 256; b0++) {
+ for (int b1 = 0; b1 < 256; b1++) {
+ for (int b2 = 0; b2 < 256; b2++) {
+ buf[start_offset + 0] = b0;
+ buf[start_offset + 1] = b1;
+ buf[start_offset + 2] = b2;
+ char* expected;
+ if (b0 == 0 || b0 == 255) {
+ expected = &buf[start_offset];
+ } else if (b1 == 0 || b1 == 255) {
+ expected = &buf[start_offset + 1];
+ } else if (b2 == 0 || b2 == 255) {
+ expected = &buf[start_offset + 2];
+ } else {
+ expected = limit;
+ }
+ count++;
+ EXPECT_EQ(expected,
+ OrderedCode::TEST_SkipToNextSpecialByte(buf, limit));
+ }
+ }
+ }
+ }
+ EXPECT_EQ(count, 256 * 256 * 256 * 2);
+}
+
+TEST(Uint64, EncodeDecode) { TestNumbers<uint64>(1); }
+
+TEST(Uint64, Ordering) { TestNumberOrdering<uint64>(); }
+
+TEST(Int64, EncodeDecode) {
+ TestNumbers<int64>(1);
+ TestNumbers<int64>(-1);
+}
+
+TEST(Int64, Ordering) { TestNumberOrdering<int64>(); }
+
+// Returns the bitwise complement of s.
+static inline string StrNot(const string& s) {
+ string result;
+ for (string::const_iterator it = s.begin(); it != s.end(); ++it)
+ result.push_back(~*it);
+ return result;
+}
+
+template <typename T>
+static void TestInvalidEncoding(const string& s) {
+ StringPiece p(s);
+ EXPECT_FALSE(OCRead<T>(&p, static_cast<T*>(NULL)));
+ EXPECT_EQ(s, p);
+}
+
+TEST(OrderedCodeInvalidEncodingsTest, Overflow) {
+  // 1 << 64, which does not fit in a uint64
+ const string k2xx64U = "\x09\x01" + string(8, 0);
+ TestInvalidEncoding<uint64>(k2xx64U);
+
+  // 1 << 63 and ~(1 << 63), which do not fit in an int64
+ const string k2xx63 = "\xff\xc0\x80" + string(7, 0);
+ TestInvalidEncoding<int64>(k2xx63);
+ TestInvalidEncoding<int64>(StrNot(k2xx63));
+}
+
+TEST(OrderedCodeInvalidEncodingsDeathTest, NonCanonical) {
+ // Test "ambiguous"/"non-canonical" encodings.
+ // These are non-minimal (but otherwise "valid") encodings that
+ // differ from the minimal encoding chosen by OrderedCode::WriteXXX
+ // and thus should be avoided to not mess up the string ordering of
+ // encodings.
+
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+
+ for (int n = 2; n <= 9; ++n) {
+ // The zero in non_minimal[1] is "redundant".
+ string non_minimal =
+ string(1, n - 1) + string(1, 0) + RandomString(&rnd, n - 2);
+ EXPECT_EQ(n, non_minimal.length());
+
+ EXPECT_NE(OCWrite<uint64>(0), non_minimal);
+#ifndef NDEBUG
+ StringPiece s(non_minimal);
+ EXPECT_DEATH(OrderedCode::ReadNumIncreasing(&s, NULL), "invalid encoding");
+#else
+ TestRead<uint64>(non_minimal);
+#endif
+ }
+
+ for (int n = 2; n <= 10; ++n) {
+ // Header with 1 sign bit and n-1 size bits.
+ string header = string(n / 8, 0xff) + string(1, 0xff << (8 - (n % 8)));
+ // There are more than 7 zero bits between header bits and "payload".
+ string non_minimal = header +
+ string(1, rnd.Uniform(256) & ~*header.rbegin()) +
+ RandomString(&rnd, n - header.length() - 1);
+ EXPECT_EQ(n, non_minimal.length());
+
+ EXPECT_NE(OCWrite<int64>(0), non_minimal);
+#ifndef NDEBUG
+ StringPiece s(non_minimal);
+ EXPECT_DEATH(OrderedCode::ReadSignedNumIncreasing(&s, NULL),
+ "invalid encoding")
+ << n;
+#else
+ TestRead<int64>(non_minimal);
+#endif
+ }
+}
+
+// Returns a random number with the specified number of bits,
+// i.e., in the range [2^(bits-1), 2^bits).
+static uint64 NextBits(random::SimplePhilox* rnd, int bits) {
+ return (bits != 0)
+ ? (rnd->Rand64() % (1LL << (bits - 1))) + (1LL << (bits - 1))
+ : 0;
+}
+
+template <typename T>
+static void BM_WriteNum(int n, T multiplier) {
+ static const int kValues = 64;
+ T values[kValues];
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ // Use enough distinct values to confuse the branch predictor
+ for (int i = 0; i < kValues; i++) {
+ values[i] = NextBits(&rnd, n % 64) * multiplier;
+ }
+ string result;
+ int index = 0;
+ while (n-- > 0) {
+ result.clear();
+ OCWriteToString<T>(&result, values[index % kValues]);
+ index++;
+ }
+}
+
+template <typename T>
+static void BM_ReadNum(int n, T multiplier) {
+ string x;
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ // Use enough distinct values to confuse the branch predictor
+ static const int kValues = 64;
+ string values[kValues];
+ for (int i = 0; i < kValues; i++) {
+ T val = NextBits(&rnd, i % 64) * multiplier;
+ values[i] = OCWrite<T>(val);
+ }
+ uint32 index = 0;
+ while (n-- > 0) {
+ T val;
+ StringPiece s = values[index++ % kValues];
+ OCRead<T>(&s, &val);
+ }
+}
+
+#define BENCHMARK_NUM(name, T, multiplier) \
+ static void BM_Write##name(int n) { BM_WriteNum<T>(n, multiplier); } \
+ BENCHMARK(BM_Write##name); \
+ static void BM_Read##name(int n) { BM_ReadNum<T>(n, multiplier); } \
+ BENCHMARK(BM_Read##name)
+
+BENCHMARK_NUM(NumIncreasing, uint64, 1);
+BENCHMARK_NUM(SignedNum, int64, 1);
+BENCHMARK_NUM(SignedNumNegative, int64, -1);
+
+#undef BENCHMARK_NUM
+
+// ---------------------------------------------------------------------
+// Strings
+
+TEST(String, EncodeDecode) {
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+
+ for (int len = 0; len < 256; len++) {
+ const string a = RandomString(&rnd, len);
+ TestWriteRead(a);
+ for (int len2 = 0; len2 < 64; len2++) {
+ const string b = RandomString(&rnd, len2);
+
+ TestWriteAppends(a, b);
+
+ string out;
+ OCWriteToString<string>(&out, a);
+ OCWriteToString<string>(&out, b);
+
+ string a2, b2, dummy;
+ StringPiece s = out;
+ StringPiece s2 = out;
+ CHECK(OCRead<string>(&s, &a2));
+ CHECK(OCRead<string>(&s2, NULL));
+ CHECK_EQ(s, s2);
+
+ CHECK(OCRead<string>(&s, &b2));
+ CHECK(OCRead<string>(&s2, NULL));
+ CHECK_EQ(s, s2);
+
+ CHECK(!OCRead<string>(&s, &dummy));
+ CHECK(!OCRead<string>(&s2, NULL));
+ CHECK_EQ(a, a2);
+ CHECK_EQ(b, b2);
+ CHECK(s.empty());
+ CHECK(s2.empty());
+ }
+ }
+}
+
+// 'str' is a static C-style string that may contain '\0'
+#define STATIC_STR(str) StringPiece((str), sizeof(str) - 1)
+
+static string EncodeStringIncreasing(StringPiece value) {
+ string encoded;
+ OrderedCode::WriteString(&encoded, value);
+ return encoded;
+}
+
+TEST(String, Increasing) {
+  // Here is a series of strings in non-decreasing order, including
+ // consecutive strings such that the second one is equal to, a proper
+ // prefix of, or has the same length as the first one. Most also contain
+ // the special escaping characters '\x00' and '\xff'.
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("")),
+ EncodeStringIncreasing(STATIC_STR("")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("")),
+ EncodeStringIncreasing(STATIC_STR("\x00")));
+
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("\x00")),
+ EncodeStringIncreasing(STATIC_STR("\x00")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x00")),
+ EncodeStringIncreasing(STATIC_STR("\x01")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\x01")),
+ EncodeStringIncreasing(STATIC_STR("a")));
+
+ ASSERT_EQ(EncodeStringIncreasing(STATIC_STR("a")),
+ EncodeStringIncreasing(STATIC_STR("a")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("a")),
+ EncodeStringIncreasing(STATIC_STR("aa")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("aa")),
+ EncodeStringIncreasing(STATIC_STR("\xff")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff")),
+ EncodeStringIncreasing(STATIC_STR("\xff\x00")));
+
+ ASSERT_LT(EncodeStringIncreasing(STATIC_STR("\xff\x00")),
+ EncodeStringIncreasing(STATIC_STR("\xff\x01")));
+}
+
+TEST(EncodingIsExpected, String) {
+ std::vector<std::pair<string, string>> data = {
+ {"", string("\x00\x01", 2)},
+ {"foo", string("foo\x00\x01", 5)},
+ {"hello", string("hello\x00\x01", 7)},
+ {string("\x00\x01\xff", 3), string("\x00\xff\x01\xff\x00\x00\x01", 7)},
+ };
+ for (const auto& t : data) {
+ string result;
+ OrderedCode::WriteString(&result, t.first);
+ EXPECT_EQ(t.second, result);
+
+ StringPiece in = result;
+ string decoded;
+ EXPECT_TRUE(OrderedCode::ReadString(&in, &decoded));
+ EXPECT_EQ(t.first, decoded);
+ EXPECT_EQ("", in);
+ }
+}
+
+TEST(EncodingIsExpected, Unsigned) {
+ std::vector<std::pair<uint64, string>> data = {
+ {0x0ull, string("\000", 1)},
+ {0x1ull, string("\001\001", 2)},
+ {0x2ull, string("\001\002", 2)},
+ {0x1ull, string("\001\001", 2)},
+ {0x2ull, string("\001\002", 2)},
+ {0x3ull, string("\001\003", 2)},
+ {0x3ull, string("\001\003", 2)},
+ {0x4ull, string("\001\004", 2)},
+ {0x5ull, string("\001\005", 2)},
+ {0x7ull, string("\001\007", 2)},
+ {0x8ull, string("\001\010", 2)},
+ {0x9ull, string("\001\t", 2)},
+ {0xfull, string("\001\017", 2)},
+ {0x10ull, string("\001\020", 2)},
+ {0x11ull, string("\001\021", 2)},
+ {0x1full, string("\001\037", 2)},
+ {0x20ull, string("\001 ", 2)},
+ {0x21ull, string("\001!", 2)},
+ {0x3full, string("\001?", 2)},
+ {0x40ull, string("\001@", 2)},
+ {0x41ull, string("\001A", 2)},
+ {0x7full, string("\001\177", 2)},
+ {0x80ull, string("\001\200", 2)},
+ {0x81ull, string("\001\201", 2)},
+ {0xffull, string("\001\377", 2)},
+ {0x100ull, string("\002\001\000", 3)},
+ {0x101ull, string("\002\001\001", 3)},
+ {0x1ffull, string("\002\001\377", 3)},
+ {0x200ull, string("\002\002\000", 3)},
+ {0x201ull, string("\002\002\001", 3)},
+ {0x3ffull, string("\002\003\377", 3)},
+ {0x400ull, string("\002\004\000", 3)},
+ {0x401ull, string("\002\004\001", 3)},
+ {0x7ffull, string("\002\007\377", 3)},
+ {0x800ull, string("\002\010\000", 3)},
+ {0x801ull, string("\002\010\001", 3)},
+ {0xfffull, string("\002\017\377", 3)},
+ {0x1000ull, string("\002\020\000", 3)},
+ {0x1001ull, string("\002\020\001", 3)},
+ {0x1fffull, string("\002\037\377", 3)},
+ {0x2000ull, string("\002 \000", 3)},
+ {0x2001ull, string("\002 \001", 3)},
+ {0x3fffull, string("\002?\377", 3)},
+ {0x4000ull, string("\002@\000", 3)},
+ {0x4001ull, string("\002@\001", 3)},
+ {0x7fffull, string("\002\177\377", 3)},
+ {0x8000ull, string("\002\200\000", 3)},
+ {0x8001ull, string("\002\200\001", 3)},
+ {0xffffull, string("\002\377\377", 3)},
+ {0x10000ull, string("\003\001\000\000", 4)},
+ {0x10001ull, string("\003\001\000\001", 4)},
+ {0x1ffffull, string("\003\001\377\377", 4)},
+ {0x20000ull, string("\003\002\000\000", 4)},
+ {0x20001ull, string("\003\002\000\001", 4)},
+ {0x3ffffull, string("\003\003\377\377", 4)},
+ {0x40000ull, string("\003\004\000\000", 4)},
+ {0x40001ull, string("\003\004\000\001", 4)},
+ {0x7ffffull, string("\003\007\377\377", 4)},
+ {0x80000ull, string("\003\010\000\000", 4)},
+ {0x80001ull, string("\003\010\000\001", 4)},
+ {0xfffffull, string("\003\017\377\377", 4)},
+ {0x100000ull, string("\003\020\000\000", 4)},
+ {0x100001ull, string("\003\020\000\001", 4)},
+ {0x1fffffull, string("\003\037\377\377", 4)},
+ {0x200000ull, string("\003 \000\000", 4)},
+ {0x200001ull, string("\003 \000\001", 4)},
+ {0x3fffffull, string("\003?\377\377", 4)},
+ {0x400000ull, string("\003@\000\000", 4)},
+ {0x400001ull, string("\003@\000\001", 4)},
+ {0x7fffffull, string("\003\177\377\377", 4)},
+ {0x800000ull, string("\003\200\000\000", 4)},
+ {0x800001ull, string("\003\200\000\001", 4)},
+ {0xffffffull, string("\003\377\377\377", 4)},
+ {0x1000000ull, string("\004\001\000\000\000", 5)},
+ {0x1000001ull, string("\004\001\000\000\001", 5)},
+ {0x1ffffffull, string("\004\001\377\377\377", 5)},
+ {0x2000000ull, string("\004\002\000\000\000", 5)},
+ {0x2000001ull, string("\004\002\000\000\001", 5)},
+ {0x3ffffffull, string("\004\003\377\377\377", 5)},
+ {0x4000000ull, string("\004\004\000\000\000", 5)},
+ {0x4000001ull, string("\004\004\000\000\001", 5)},
+ {0x7ffffffull, string("\004\007\377\377\377", 5)},
+ {0x8000000ull, string("\004\010\000\000\000", 5)},
+ {0x8000001ull, string("\004\010\000\000\001", 5)},
+ {0xfffffffull, string("\004\017\377\377\377", 5)},
+ {0x10000000ull, string("\004\020\000\000\000", 5)},
+ {0x10000001ull, string("\004\020\000\000\001", 5)},
+ {0x1fffffffull, string("\004\037\377\377\377", 5)},
+ {0x20000000ull, string("\004 \000\000\000", 5)},
+ {0x20000001ull, string("\004 \000\000\001", 5)},
+ {0x3fffffffull, string("\004?\377\377\377", 5)},
+ {0x40000000ull, string("\004@\000\000\000", 5)},
+ {0x40000001ull, string("\004@\000\000\001", 5)},
+ {0x7fffffffull, string("\004\177\377\377\377", 5)},
+ {0x80000000ull, string("\004\200\000\000\000", 5)},
+ {0x80000001ull, string("\004\200\000\000\001", 5)},
+ {0xffffffffull, string("\004\377\377\377\377", 5)},
+ {0x100000000ull, string("\005\001\000\000\000\000", 6)},
+ {0x100000001ull, string("\005\001\000\000\000\001", 6)},
+ {0x1ffffffffull, string("\005\001\377\377\377\377", 6)},
+ {0x200000000ull, string("\005\002\000\000\000\000", 6)},
+ {0x200000001ull, string("\005\002\000\000\000\001", 6)},
+ {0x3ffffffffull, string("\005\003\377\377\377\377", 6)},
+ {0x400000000ull, string("\005\004\000\000\000\000", 6)},
+ {0x400000001ull, string("\005\004\000\000\000\001", 6)},
+ {0x7ffffffffull, string("\005\007\377\377\377\377", 6)},
+ {0x800000000ull, string("\005\010\000\000\000\000", 6)},
+ {0x800000001ull, string("\005\010\000\000\000\001", 6)},
+ {0xfffffffffull, string("\005\017\377\377\377\377", 6)},
+ {0x1000000000ull, string("\005\020\000\000\000\000", 6)},
+ {0x1000000001ull, string("\005\020\000\000\000\001", 6)},
+ {0x1fffffffffull, string("\005\037\377\377\377\377", 6)},
+ {0x2000000000ull, string("\005 \000\000\000\000", 6)},
+ {0x2000000001ull, string("\005 \000\000\000\001", 6)},
+ {0x3fffffffffull, string("\005?\377\377\377\377", 6)},
+ {0x4000000000ull, string("\005@\000\000\000\000", 6)},
+ {0x4000000001ull, string("\005@\000\000\000\001", 6)},
+ {0x7fffffffffull, string("\005\177\377\377\377\377", 6)},
+ {0x8000000000ull, string("\005\200\000\000\000\000", 6)},
+ {0x8000000001ull, string("\005\200\000\000\000\001", 6)},
+ {0xffffffffffull, string("\005\377\377\377\377\377", 6)},
+ {0x10000000000ull, string("\006\001\000\000\000\000\000", 7)},
+ {0x10000000001ull, string("\006\001\000\000\000\000\001", 7)},
+ {0x1ffffffffffull, string("\006\001\377\377\377\377\377", 7)},
+ {0x20000000000ull, string("\006\002\000\000\000\000\000", 7)},
+ {0x20000000001ull, string("\006\002\000\000\000\000\001", 7)},
+ {0x3ffffffffffull, string("\006\003\377\377\377\377\377", 7)},
+ {0x40000000000ull, string("\006\004\000\000\000\000\000", 7)},
+ {0x40000000001ull, string("\006\004\000\000\000\000\001", 7)},
+ {0x7ffffffffffull, string("\006\007\377\377\377\377\377", 7)},
+ {0x80000000000ull, string("\006\010\000\000\000\000\000", 7)},
+ {0x80000000001ull, string("\006\010\000\000\000\000\001", 7)},
+ {0xfffffffffffull, string("\006\017\377\377\377\377\377", 7)},
+ {0x100000000000ull, string("\006\020\000\000\000\000\000", 7)},
+ {0x100000000001ull, string("\006\020\000\000\000\000\001", 7)},
+ {0x1fffffffffffull, string("\006\037\377\377\377\377\377", 7)},
+ {0x200000000000ull, string("\006 \000\000\000\000\000", 7)},
+ {0x200000000001ull, string("\006 \000\000\000\000\001", 7)},
+ {0x3fffffffffffull, string("\006?\377\377\377\377\377", 7)},
+ {0x400000000000ull, string("\006@\000\000\000\000\000", 7)},
+ {0x400000000001ull, string("\006@\000\000\000\000\001", 7)},
+ {0x7fffffffffffull, string("\006\177\377\377\377\377\377", 7)},
+ {0x800000000000ull, string("\006\200\000\000\000\000\000", 7)},
+ {0x800000000001ull, string("\006\200\000\000\000\000\001", 7)},
+ {0xffffffffffffull, string("\006\377\377\377\377\377\377", 7)},
+ {0x1000000000000ull, string("\007\001\000\000\000\000\000\000", 8)},
+ {0x1000000000001ull, string("\007\001\000\000\000\000\000\001", 8)},
+ {0x1ffffffffffffull, string("\007\001\377\377\377\377\377\377", 8)},
+ {0x2000000000000ull, string("\007\002\000\000\000\000\000\000", 8)},
+ {0x2000000000001ull, string("\007\002\000\000\000\000\000\001", 8)},
+ {0x3ffffffffffffull, string("\007\003\377\377\377\377\377\377", 8)},
+ {0x4000000000000ull, string("\007\004\000\000\000\000\000\000", 8)},
+ {0x4000000000001ull, string("\007\004\000\000\000\000\000\001", 8)},
+ {0x7ffffffffffffull, string("\007\007\377\377\377\377\377\377", 8)},
+ {0x8000000000000ull, string("\007\010\000\000\000\000\000\000", 8)},
+ {0x8000000000001ull, string("\007\010\000\000\000\000\000\001", 8)},
+ {0xfffffffffffffull, string("\007\017\377\377\377\377\377\377", 8)},
+ {0x10000000000000ull, string("\007\020\000\000\000\000\000\000", 8)},
+ {0x10000000000001ull, string("\007\020\000\000\000\000\000\001", 8)},
+ {0x1fffffffffffffull, string("\007\037\377\377\377\377\377\377", 8)},
+ {0x20000000000000ull, string("\007 \000\000\000\000\000\000", 8)},
+ {0x20000000000001ull, string("\007 \000\000\000\000\000\001", 8)},
+ {0x3fffffffffffffull, string("\007?\377\377\377\377\377\377", 8)},
+ {0x40000000000000ull, string("\007@\000\000\000\000\000\000", 8)},
+ {0x40000000000001ull, string("\007@\000\000\000\000\000\001", 8)},
+ {0x7fffffffffffffull, string("\007\177\377\377\377\377\377\377", 8)},
+ {0x80000000000000ull, string("\007\200\000\000\000\000\000\000", 8)},
+ {0x80000000000001ull, string("\007\200\000\000\000\000\000\001", 8)},
+ {0xffffffffffffffull, string("\007\377\377\377\377\377\377\377", 8)},
+ {0x100000000000000ull, string("\010\001\000\000\000\000\000\000\000", 9)},
+ {0x100000000000001ull, string("\010\001\000\000\000\000\000\000\001", 9)},
+ {0x1ffffffffffffffull, string("\010\001\377\377\377\377\377\377\377", 9)},
+ {0x200000000000000ull, string("\010\002\000\000\000\000\000\000\000", 9)},
+ {0x200000000000001ull, string("\010\002\000\000\000\000\000\000\001", 9)},
+ {0x3ffffffffffffffull, string("\010\003\377\377\377\377\377\377\377", 9)},
+ {0x400000000000000ull, string("\010\004\000\000\000\000\000\000\000", 9)},
+ {0x400000000000001ull, string("\010\004\000\000\000\000\000\000\001", 9)},
+ {0x7ffffffffffffffull, string("\010\007\377\377\377\377\377\377\377", 9)},
+ {0x800000000000000ull, string("\010\010\000\000\000\000\000\000\000", 9)},
+ {0x800000000000001ull, string("\010\010\000\000\000\000\000\000\001", 9)},
+ {0xfffffffffffffffull, string("\010\017\377\377\377\377\377\377\377", 9)},
+ {0x1000000000000000ull,
+ string("\010\020\000\000\000\000\000\000\000", 9)},
+ {0x1000000000000001ull,
+ string("\010\020\000\000\000\000\000\000\001", 9)},
+ {0x1fffffffffffffffull,
+ string("\010\037\377\377\377\377\377\377\377", 9)},
+ {0x2000000000000000ull, string("\010 \000\000\000\000\000\000\000", 9)},
+ {0x2000000000000001ull, string("\010 \000\000\000\000\000\000\001", 9)},
+ {0x3fffffffffffffffull, string("\010?\377\377\377\377\377\377\377", 9)},
+ {0x4000000000000000ull, string("\010@\000\000\000\000\000\000\000", 9)},
+ {0x4000000000000001ull, string("\010@\000\000\000\000\000\000\001", 9)},
+ {0x7fffffffffffffffull,
+ string("\010\177\377\377\377\377\377\377\377", 9)},
+ {0x8000000000000000ull,
+ string("\010\200\000\000\000\000\000\000\000", 9)},
+ {0x8000000000000001ull,
+ string("\010\200\000\000\000\000\000\000\001", 9)},
+ };
+ for (const auto& t : data) {
+ uint64 num = t.first;
+ string result;
+ OrderedCode::WriteNumIncreasing(&result, num);
+ EXPECT_EQ(t.second, result) << std::hex << num;
+
+ StringPiece in = result;
+ uint64 decoded;
+ EXPECT_TRUE(OrderedCode::ReadNumIncreasing(&in, &decoded));
+ EXPECT_EQ(num, decoded);
+ EXPECT_EQ("", in);
+ }
+}
+
+TEST(EncodingIsExpected, Signed) {
+ std::vector<std::pair<int64, string>> data = {
+ {0ll, string("\200", 1)},
+ {1ll, string("\201", 1)},
+ {2ll, string("\202", 1)},
+ {1ll, string("\201", 1)},
+ {2ll, string("\202", 1)},
+ {3ll, string("\203", 1)},
+ {3ll, string("\203", 1)},
+ {4ll, string("\204", 1)},
+ {5ll, string("\205", 1)},
+ {7ll, string("\207", 1)},
+ {8ll, string("\210", 1)},
+ {9ll, string("\211", 1)},
+ {15ll, string("\217", 1)},
+ {16ll, string("\220", 1)},
+ {17ll, string("\221", 1)},
+ {31ll, string("\237", 1)},
+ {32ll, string("\240", 1)},
+ {33ll, string("\241", 1)},
+ {63ll, string("\277", 1)},
+ {64ll, string("\300@", 2)},
+ {65ll, string("\300A", 2)},
+ {127ll, string("\300\177", 2)},
+ {128ll, string("\300\200", 2)},
+ {129ll, string("\300\201", 2)},
+ {255ll, string("\300\377", 2)},
+ {256ll, string("\301\000", 2)},
+ {257ll, string("\301\001", 2)},
+ {511ll, string("\301\377", 2)},
+ {512ll, string("\302\000", 2)},
+ {513ll, string("\302\001", 2)},
+ {1023ll, string("\303\377", 2)},
+ {1024ll, string("\304\000", 2)},
+ {1025ll, string("\304\001", 2)},
+ {2047ll, string("\307\377", 2)},
+ {2048ll, string("\310\000", 2)},
+ {2049ll, string("\310\001", 2)},
+ {4095ll, string("\317\377", 2)},
+ {4096ll, string("\320\000", 2)},
+ {4097ll, string("\320\001", 2)},
+ {8191ll, string("\337\377", 2)},
+ {8192ll, string("\340 \000", 3)},
+ {8193ll, string("\340 \001", 3)},
+ {16383ll, string("\340?\377", 3)},
+ {16384ll, string("\340@\000", 3)},
+ {16385ll, string("\340@\001", 3)},
+ {32767ll, string("\340\177\377", 3)},
+ {32768ll, string("\340\200\000", 3)},
+ {32769ll, string("\340\200\001", 3)},
+ {65535ll, string("\340\377\377", 3)},
+ {65536ll, string("\341\000\000", 3)},
+ {65537ll, string("\341\000\001", 3)},
+ {131071ll, string("\341\377\377", 3)},
+ {131072ll, string("\342\000\000", 3)},
+ {131073ll, string("\342\000\001", 3)},
+ {262143ll, string("\343\377\377", 3)},
+ {262144ll, string("\344\000\000", 3)},
+ {262145ll, string("\344\000\001", 3)},
+ {524287ll, string("\347\377\377", 3)},
+ {524288ll, string("\350\000\000", 3)},
+ {524289ll, string("\350\000\001", 3)},
+ {1048575ll, string("\357\377\377", 3)},
+ {1048576ll, string("\360\020\000\000", 4)},
+ {1048577ll, string("\360\020\000\001", 4)},
+ {2097151ll, string("\360\037\377\377", 4)},
+ {2097152ll, string("\360 \000\000", 4)},
+ {2097153ll, string("\360 \000\001", 4)},
+ {4194303ll, string("\360?\377\377", 4)},
+ {4194304ll, string("\360@\000\000", 4)},
+ {4194305ll, string("\360@\000\001", 4)},
+ {8388607ll, string("\360\177\377\377", 4)},
+ {8388608ll, string("\360\200\000\000", 4)},
+ {8388609ll, string("\360\200\000\001", 4)},
+ {16777215ll, string("\360\377\377\377", 4)},
+ {16777216ll, string("\361\000\000\000", 4)},
+ {16777217ll, string("\361\000\000\001", 4)},
+ {33554431ll, string("\361\377\377\377", 4)},
+ {33554432ll, string("\362\000\000\000", 4)},
+ {33554433ll, string("\362\000\000\001", 4)},
+ {67108863ll, string("\363\377\377\377", 4)},
+ {67108864ll, string("\364\000\000\000", 4)},
+ {67108865ll, string("\364\000\000\001", 4)},
+ {134217727ll, string("\367\377\377\377", 4)},
+ {134217728ll, string("\370\010\000\000\000", 5)},
+ {134217729ll, string("\370\010\000\000\001", 5)},
+ {268435455ll, string("\370\017\377\377\377", 5)},
+ {268435456ll, string("\370\020\000\000\000", 5)},
+ {268435457ll, string("\370\020\000\000\001", 5)},
+ {536870911ll, string("\370\037\377\377\377", 5)},
+ {536870912ll, string("\370 \000\000\000", 5)},
+ {536870913ll, string("\370 \000\000\001", 5)},
+ {1073741823ll, string("\370?\377\377\377", 5)},
+ {1073741824ll, string("\370@\000\000\000", 5)},
+ {1073741825ll, string("\370@\000\000\001", 5)},
+ {2147483647ll, string("\370\177\377\377\377", 5)},
+ {2147483648ll, string("\370\200\000\000\000", 5)},
+ {2147483649ll, string("\370\200\000\000\001", 5)},
+ {4294967295ll, string("\370\377\377\377\377", 5)},
+ {4294967296ll, string("\371\000\000\000\000", 5)},
+ {4294967297ll, string("\371\000\000\000\001", 5)},
+ {8589934591ll, string("\371\377\377\377\377", 5)},
+ {8589934592ll, string("\372\000\000\000\000", 5)},
+ {8589934593ll, string("\372\000\000\000\001", 5)},
+ {17179869183ll, string("\373\377\377\377\377", 5)},
+ {17179869184ll, string("\374\004\000\000\000\000", 6)},
+ {17179869185ll, string("\374\004\000\000\000\001", 6)},
+ {34359738367ll, string("\374\007\377\377\377\377", 6)},
+ {34359738368ll, string("\374\010\000\000\000\000", 6)},
+ {34359738369ll, string("\374\010\000\000\000\001", 6)},
+ {68719476735ll, string("\374\017\377\377\377\377", 6)},
+ {68719476736ll, string("\374\020\000\000\000\000", 6)},
+ {68719476737ll, string("\374\020\000\000\000\001", 6)},
+ {137438953471ll, string("\374\037\377\377\377\377", 6)},
+ {137438953472ll, string("\374 \000\000\000\000", 6)},
+ {137438953473ll, string("\374 \000\000\000\001", 6)},
+ {274877906943ll, string("\374?\377\377\377\377", 6)},
+ {274877906944ll, string("\374@\000\000\000\000", 6)},
+ {274877906945ll, string("\374@\000\000\000\001", 6)},
+ {549755813887ll, string("\374\177\377\377\377\377", 6)},
+ {549755813888ll, string("\374\200\000\000\000\000", 6)},
+ {549755813889ll, string("\374\200\000\000\000\001", 6)},
+ {1099511627775ll, string("\374\377\377\377\377\377", 6)},
+ {1099511627776ll, string("\375\000\000\000\000\000", 6)},
+ {1099511627777ll, string("\375\000\000\000\000\001", 6)},
+ {2199023255551ll, string("\375\377\377\377\377\377", 6)},
+ {2199023255552ll, string("\376\002\000\000\000\000\000", 7)},
+ {2199023255553ll, string("\376\002\000\000\000\000\001", 7)},
+ {4398046511103ll, string("\376\003\377\377\377\377\377", 7)},
+ {4398046511104ll, string("\376\004\000\000\000\000\000", 7)},
+ {4398046511105ll, string("\376\004\000\000\000\000\001", 7)},
+ {8796093022207ll, string("\376\007\377\377\377\377\377", 7)},
+ {8796093022208ll, string("\376\010\000\000\000\000\000", 7)},
+ {8796093022209ll, string("\376\010\000\000\000\000\001", 7)},
+ {17592186044415ll, string("\376\017\377\377\377\377\377", 7)},
+ {17592186044416ll, string("\376\020\000\000\000\000\000", 7)},
+ {17592186044417ll, string("\376\020\000\000\000\000\001", 7)},
+ {35184372088831ll, string("\376\037\377\377\377\377\377", 7)},
+ {35184372088832ll, string("\376 \000\000\000\000\000", 7)},
+ {35184372088833ll, string("\376 \000\000\000\000\001", 7)},
+ {70368744177663ll, string("\376?\377\377\377\377\377", 7)},
+ {70368744177664ll, string("\376@\000\000\000\000\000", 7)},
+ {70368744177665ll, string("\376@\000\000\000\000\001", 7)},
+ {140737488355327ll, string("\376\177\377\377\377\377\377", 7)},
+ {140737488355328ll, string("\376\200\000\000\000\000\000", 7)},
+ {140737488355329ll, string("\376\200\000\000\000\000\001", 7)},
+ {281474976710655ll, string("\376\377\377\377\377\377\377", 7)},
+ {281474976710656ll, string("\377\001\000\000\000\000\000\000", 8)},
+ {281474976710657ll, string("\377\001\000\000\000\000\000\001", 8)},
+ {562949953421311ll, string("\377\001\377\377\377\377\377\377", 8)},
+ {562949953421312ll, string("\377\002\000\000\000\000\000\000", 8)},
+ {562949953421313ll, string("\377\002\000\000\000\000\000\001", 8)},
+ {1125899906842623ll, string("\377\003\377\377\377\377\377\377", 8)},
+ {1125899906842624ll, string("\377\004\000\000\000\000\000\000", 8)},
+ {1125899906842625ll, string("\377\004\000\000\000\000\000\001", 8)},
+ {2251799813685247ll, string("\377\007\377\377\377\377\377\377", 8)},
+ {2251799813685248ll, string("\377\010\000\000\000\000\000\000", 8)},
+ {2251799813685249ll, string("\377\010\000\000\000\000\000\001", 8)},
+ {4503599627370495ll, string("\377\017\377\377\377\377\377\377", 8)},
+ {4503599627370496ll, string("\377\020\000\000\000\000\000\000", 8)},
+ {4503599627370497ll, string("\377\020\000\000\000\000\000\001", 8)},
+ {9007199254740991ll, string("\377\037\377\377\377\377\377\377", 8)},
+ {9007199254740992ll, string("\377 \000\000\000\000\000\000", 8)},
+ {9007199254740993ll, string("\377 \000\000\000\000\000\001", 8)},
+ {18014398509481983ll, string("\377?\377\377\377\377\377\377", 8)},
+ {18014398509481984ll, string("\377@\000\000\000\000\000\000", 8)},
+ {18014398509481985ll, string("\377@\000\000\000\000\000\001", 8)},
+ {36028797018963967ll, string("\377\177\377\377\377\377\377\377", 8)},
+ {36028797018963968ll, string("\377\200\200\000\000\000\000\000\000", 9)},
+ {36028797018963969ll, string("\377\200\200\000\000\000\000\000\001", 9)},
+ {72057594037927935ll, string("\377\200\377\377\377\377\377\377\377", 9)},
+ {72057594037927936ll, string("\377\201\000\000\000\000\000\000\000", 9)},
+ {72057594037927937ll, string("\377\201\000\000\000\000\000\000\001", 9)},
+ {144115188075855871ll, string("\377\201\377\377\377\377\377\377\377", 9)},
+ {144115188075855872ll, string("\377\202\000\000\000\000\000\000\000", 9)},
+ {144115188075855873ll, string("\377\202\000\000\000\000\000\000\001", 9)},
+ {288230376151711743ll, string("\377\203\377\377\377\377\377\377\377", 9)},
+ {288230376151711744ll, string("\377\204\000\000\000\000\000\000\000", 9)},
+ {288230376151711745ll, string("\377\204\000\000\000\000\000\000\001", 9)},
+ {576460752303423487ll, string("\377\207\377\377\377\377\377\377\377", 9)},
+ {576460752303423488ll, string("\377\210\000\000\000\000\000\000\000", 9)},
+ {576460752303423489ll, string("\377\210\000\000\000\000\000\000\001", 9)},
+ {1152921504606846975ll,
+ string("\377\217\377\377\377\377\377\377\377", 9)},
+ {1152921504606846976ll,
+ string("\377\220\000\000\000\000\000\000\000", 9)},
+ {1152921504606846977ll,
+ string("\377\220\000\000\000\000\000\000\001", 9)},
+ {2305843009213693951ll,
+ string("\377\237\377\377\377\377\377\377\377", 9)},
+ {2305843009213693952ll,
+ string("\377\240\000\000\000\000\000\000\000", 9)},
+ {2305843009213693953ll,
+ string("\377\240\000\000\000\000\000\000\001", 9)},
+ {4611686018427387903ll,
+ string("\377\277\377\377\377\377\377\377\377", 9)},
+ {4611686018427387904ll,
+ string("\377\300@\000\000\000\000\000\000\000", 10)},
+ {4611686018427387905ll,
+ string("\377\300@\000\000\000\000\000\000\001", 10)},
+ {9223372036854775807ll,
+ string("\377\300\177\377\377\377\377\377\377\377", 10)},
+ {-9223372036854775807ll,
+ string("\000?\200\000\000\000\000\000\000\001", 10)},
+ {0ll, string("\200", 1)},
+ {-1ll, string("\177", 1)},
+ {-2ll, string("~", 1)},
+ {-1ll, string("\177", 1)},
+ {-2ll, string("~", 1)},
+ {-3ll, string("}", 1)},
+ {-3ll, string("}", 1)},
+ {-4ll, string("|", 1)},
+ {-5ll, string("{", 1)},
+ {-7ll, string("y", 1)},
+ {-8ll, string("x", 1)},
+ {-9ll, string("w", 1)},
+ {-15ll, string("q", 1)},
+ {-16ll, string("p", 1)},
+ {-17ll, string("o", 1)},
+ {-31ll, string("a", 1)},
+ {-32ll, string("`", 1)},
+ {-33ll, string("_", 1)},
+ {-63ll, string("A", 1)},
+ {-64ll, string("@", 1)},
+ {-65ll, string("?\277", 2)},
+ {-127ll, string("?\201", 2)},
+ {-128ll, string("?\200", 2)},
+ {-129ll, string("?\177", 2)},
+ {-255ll, string("?\001", 2)},
+ {-256ll, string("?\000", 2)},
+ {-257ll, string(">\377", 2)},
+ {-511ll, string(">\001", 2)},
+ {-512ll, string(">\000", 2)},
+ {-513ll, string("=\377", 2)},
+ {-1023ll, string("<\001", 2)},
+ {-1024ll, string("<\000", 2)},
+ {-1025ll, string(";\377", 2)},
+ {-2047ll, string("8\001", 2)},
+ {-2048ll, string("8\000", 2)},
+ {-2049ll, string("7\377", 2)},
+ {-4095ll, string("0\001", 2)},
+ {-4096ll, string("0\000", 2)},
+ {-4097ll, string("/\377", 2)},
+ {-8191ll, string(" \001", 2)},
+ {-8192ll, string(" \000", 2)},
+ {-8193ll, string("\037\337\377", 3)},
+ {-16383ll, string("\037\300\001", 3)},
+ {-16384ll, string("\037\300\000", 3)},
+ {-16385ll, string("\037\277\377", 3)},
+ {-32767ll, string("\037\200\001", 3)},
+ {-32768ll, string("\037\200\000", 3)},
+ {-32769ll, string("\037\177\377", 3)},
+ {-65535ll, string("\037\000\001", 3)},
+ {-65536ll, string("\037\000\000", 3)},
+ {-65537ll, string("\036\377\377", 3)},
+ {-131071ll, string("\036\000\001", 3)},
+ {-131072ll, string("\036\000\000", 3)},
+ {-131073ll, string("\035\377\377", 3)},
+ {-262143ll, string("\034\000\001", 3)},
+ {-262144ll, string("\034\000\000", 3)},
+ {-262145ll, string("\033\377\377", 3)},
+ {-524287ll, string("\030\000\001", 3)},
+ {-524288ll, string("\030\000\000", 3)},
+ {-524289ll, string("\027\377\377", 3)},
+ {-1048575ll, string("\020\000\001", 3)},
+ {-1048576ll, string("\020\000\000", 3)},
+ {-1048577ll, string("\017\357\377\377", 4)},
+ {-2097151ll, string("\017\340\000\001", 4)},
+ {-2097152ll, string("\017\340\000\000", 4)},
+ {-2097153ll, string("\017\337\377\377", 4)},
+ {-4194303ll, string("\017\300\000\001", 4)},
+ {-4194304ll, string("\017\300\000\000", 4)},
+ {-4194305ll, string("\017\277\377\377", 4)},
+ {-8388607ll, string("\017\200\000\001", 4)},
+ {-8388608ll, string("\017\200\000\000", 4)},
+ {-8388609ll, string("\017\177\377\377", 4)},
+ {-16777215ll, string("\017\000\000\001", 4)},
+ {-16777216ll, string("\017\000\000\000", 4)},
+ {-16777217ll, string("\016\377\377\377", 4)},
+ {-33554431ll, string("\016\000\000\001", 4)},
+ {-33554432ll, string("\016\000\000\000", 4)},
+ {-33554433ll, string("\r\377\377\377", 4)},
+ {-67108863ll, string("\014\000\000\001", 4)},
+ {-67108864ll, string("\014\000\000\000", 4)},
+ {-67108865ll, string("\013\377\377\377", 4)},
+ {-134217727ll, string("\010\000\000\001", 4)},
+ {-134217728ll, string("\010\000\000\000", 4)},
+ {-134217729ll, string("\007\367\377\377\377", 5)},
+ {-268435455ll, string("\007\360\000\000\001", 5)},
+ {-268435456ll, string("\007\360\000\000\000", 5)},
+ {-268435457ll, string("\007\357\377\377\377", 5)},
+ {-536870911ll, string("\007\340\000\000\001", 5)},
+ {-536870912ll, string("\007\340\000\000\000", 5)},
+ {-536870913ll, string("\007\337\377\377\377", 5)},
+ {-1073741823ll, string("\007\300\000\000\001", 5)},
+ {-1073741824ll, string("\007\300\000\000\000", 5)},
+ {-1073741825ll, string("\007\277\377\377\377", 5)},
+ {-2147483647ll, string("\007\200\000\000\001", 5)},
+ {-2147483648ll, string("\007\200\000\000\000", 5)},
+ {-2147483649ll, string("\007\177\377\377\377", 5)},
+ {-4294967295ll, string("\007\000\000\000\001", 5)},
+ {-4294967296ll, string("\007\000\000\000\000", 5)},
+ {-4294967297ll, string("\006\377\377\377\377", 5)},
+ {-8589934591ll, string("\006\000\000\000\001", 5)},
+ {-8589934592ll, string("\006\000\000\000\000", 5)},
+ {-8589934593ll, string("\005\377\377\377\377", 5)},
+ {-17179869183ll, string("\004\000\000\000\001", 5)},
+ {-17179869184ll, string("\004\000\000\000\000", 5)},
+ {-17179869185ll, string("\003\373\377\377\377\377", 6)},
+ {-34359738367ll, string("\003\370\000\000\000\001", 6)},
+ {-34359738368ll, string("\003\370\000\000\000\000", 6)},
+ {-34359738369ll, string("\003\367\377\377\377\377", 6)},
+ {-68719476735ll, string("\003\360\000\000\000\001", 6)},
+ {-68719476736ll, string("\003\360\000\000\000\000", 6)},
+ {-68719476737ll, string("\003\357\377\377\377\377", 6)},
+ {-137438953471ll, string("\003\340\000\000\000\001", 6)},
+ {-137438953472ll, string("\003\340\000\000\000\000", 6)},
+ {-137438953473ll, string("\003\337\377\377\377\377", 6)},
+ {-274877906943ll, string("\003\300\000\000\000\001", 6)},
+ {-274877906944ll, string("\003\300\000\000\000\000", 6)},
+ {-274877906945ll, string("\003\277\377\377\377\377", 6)},
+ {-549755813887ll, string("\003\200\000\000\000\001", 6)},
+ {-549755813888ll, string("\003\200\000\000\000\000", 6)},
+ {-549755813889ll, string("\003\177\377\377\377\377", 6)},
+ {-1099511627775ll, string("\003\000\000\000\000\001", 6)},
+ {-1099511627776ll, string("\003\000\000\000\000\000", 6)},
+ {-1099511627777ll, string("\002\377\377\377\377\377", 6)},
+ {-2199023255551ll, string("\002\000\000\000\000\001", 6)},
+ {-2199023255552ll, string("\002\000\000\000\000\000", 6)},
+ {-2199023255553ll, string("\001\375\377\377\377\377\377", 7)},
+ {-4398046511103ll, string("\001\374\000\000\000\000\001", 7)},
+ {-4398046511104ll, string("\001\374\000\000\000\000\000", 7)},
+ {-4398046511105ll, string("\001\373\377\377\377\377\377", 7)},
+ {-8796093022207ll, string("\001\370\000\000\000\000\001", 7)},
+ {-8796093022208ll, string("\001\370\000\000\000\000\000", 7)},
+ {-8796093022209ll, string("\001\367\377\377\377\377\377", 7)},
+ {-17592186044415ll, string("\001\360\000\000\000\000\001", 7)},
+ {-17592186044416ll, string("\001\360\000\000\000\000\000", 7)},
+ {-17592186044417ll, string("\001\357\377\377\377\377\377", 7)},
+ {-35184372088831ll, string("\001\340\000\000\000\000\001", 7)},
+ {-35184372088832ll, string("\001\340\000\000\000\000\000", 7)},
+ {-35184372088833ll, string("\001\337\377\377\377\377\377", 7)},
+ {-70368744177663ll, string("\001\300\000\000\000\000\001", 7)},
+ {-70368744177664ll, string("\001\300\000\000\000\000\000", 7)},
+ {-70368744177665ll, string("\001\277\377\377\377\377\377", 7)},
+ {-140737488355327ll, string("\001\200\000\000\000\000\001", 7)},
+ {-140737488355328ll, string("\001\200\000\000\000\000\000", 7)},
+ {-140737488355329ll, string("\001\177\377\377\377\377\377", 7)},
+ {-281474976710655ll, string("\001\000\000\000\000\000\001", 7)},
+ {-281474976710656ll, string("\001\000\000\000\000\000\000", 7)},
+ {-281474976710657ll, string("\000\376\377\377\377\377\377\377", 8)},
+ {-562949953421311ll, string("\000\376\000\000\000\000\000\001", 8)},
+ {-562949953421312ll, string("\000\376\000\000\000\000\000\000", 8)},
+ {-562949953421313ll, string("\000\375\377\377\377\377\377\377", 8)},
+ {-1125899906842623ll, string("\000\374\000\000\000\000\000\001", 8)},
+ {-1125899906842624ll, string("\000\374\000\000\000\000\000\000", 8)},
+ {-1125899906842625ll, string("\000\373\377\377\377\377\377\377", 8)},
+ {-2251799813685247ll, string("\000\370\000\000\000\000\000\001", 8)},
+ {-2251799813685248ll, string("\000\370\000\000\000\000\000\000", 8)},
+ {-2251799813685249ll, string("\000\367\377\377\377\377\377\377", 8)},
+ {-4503599627370495ll, string("\000\360\000\000\000\000\000\001", 8)},
+ {-4503599627370496ll, string("\000\360\000\000\000\000\000\000", 8)},
+ {-4503599627370497ll, string("\000\357\377\377\377\377\377\377", 8)},
+ {-9007199254740991ll, string("\000\340\000\000\000\000\000\001", 8)},
+ {-9007199254740992ll, string("\000\340\000\000\000\000\000\000", 8)},
+ {-9007199254740993ll, string("\000\337\377\377\377\377\377\377", 8)},
+ {-18014398509481983ll, string("\000\300\000\000\000\000\000\001", 8)},
+ {-18014398509481984ll, string("\000\300\000\000\000\000\000\000", 8)},
+ {-18014398509481985ll, string("\000\277\377\377\377\377\377\377", 8)},
+ {-36028797018963967ll, string("\000\200\000\000\000\000\000\001", 8)},
+ {-36028797018963968ll, string("\000\200\000\000\000\000\000\000", 8)},
+ {-36028797018963969ll, string("\000\177\177\377\377\377\377\377\377", 9)},
+ {-72057594037927935ll, string("\000\177\000\000\000\000\000\000\001", 9)},
+ {-72057594037927936ll, string("\000\177\000\000\000\000\000\000\000", 9)},
+ {-72057594037927937ll, string("\000~\377\377\377\377\377\377\377", 9)},
+ {-144115188075855871ll, string("\000~\000\000\000\000\000\000\001", 9)},
+ {-144115188075855872ll, string("\000~\000\000\000\000\000\000\000", 9)},
+ {-144115188075855873ll, string("\000}\377\377\377\377\377\377\377", 9)},
+ {-288230376151711743ll, string("\000|\000\000\000\000\000\000\001", 9)},
+ {-288230376151711744ll, string("\000|\000\000\000\000\000\000\000", 9)},
+ {-288230376151711745ll, string("\000{\377\377\377\377\377\377\377", 9)},
+ {-576460752303423487ll, string("\000x\000\000\000\000\000\000\001", 9)},
+ {-576460752303423488ll, string("\000x\000\000\000\000\000\000\000", 9)},
+ {-576460752303423489ll, string("\000w\377\377\377\377\377\377\377", 9)},
+ {-1152921504606846975ll, string("\000p\000\000\000\000\000\000\001", 9)},
+ {-1152921504606846976ll, string("\000p\000\000\000\000\000\000\000", 9)},
+ {-1152921504606846977ll, string("\000o\377\377\377\377\377\377\377", 9)},
+ {-2305843009213693951ll, string("\000`\000\000\000\000\000\000\001", 9)},
+ {-2305843009213693952ll, string("\000`\000\000\000\000\000\000\000", 9)},
+ {-2305843009213693953ll, string("\000_\377\377\377\377\377\377\377", 9)},
+ {-4611686018427387903ll, string("\000@\000\000\000\000\000\000\001", 9)},
+ {-4611686018427387904ll, string("\000@\000\000\000\000\000\000\000", 9)},
+ {-4611686018427387905ll,
+ string("\000?\277\377\377\377\377\377\377\377", 10)},
+ {-9223372036854775807ll,
+ string("\000?\200\000\000\000\000\000\000\001", 10)},
+ {9223372036854775807ll,
+ string("\377\300\177\377\377\377\377\377\377\377", 10)},
+ };
+ for (const auto& t : data) {
+ int64 num = t.first;
+ string result;
+ OrderedCode::WriteSignedNumIncreasing(&result, num);
+ EXPECT_EQ(t.second, result) << std::hex << num;
+
+ StringPiece in = result;
+ int64 decoded;
+ EXPECT_TRUE(OrderedCode::ReadSignedNumIncreasing(&in, &decoded));
+ EXPECT_EQ(num, decoded);
+ EXPECT_EQ("", in);
+ }
+}
+
+static void BM_WriteString(int n, int len) {
+ testing::StopTiming();
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string x;
+ for (int i = 0; i < len; i++) {
+ x += rnd.Uniform(256);
+ }
+ string y;
+
+  testing::BytesProcessed(static_cast<int64>(n) * len);
+ testing::StartTiming();
+ while (n-- > 0) {
+ y.clear();
+ OCWriteToString<string>(&y, x);
+ }
+}
+
+static void BM_ReadString(int n, int len) {
+ testing::StopTiming();
+ random::PhiloxRandom philox(301, 17);
+ random::SimplePhilox rnd(&philox);
+ string x;
+ for (int i = 0; i < len; i++) {
+ x += rnd.Uniform(256);
+ }
+ string data;
+ OCWriteToString<string>(&data, x);
+ string result;
+
+  testing::BytesProcessed(static_cast<int64>(n) * len);
+ testing::StartTiming();
+ while (n-- > 0) {
+ result.clear();
+ StringPiece s = data;
+ OCRead<string>(&s, &result);
+ }
+}
+
+static void BM_WriteStringIncreasing(int n, int len) { BM_WriteString(n, len); }
+static void BM_ReadStringIncreasing(int n, int len) { BM_ReadString(n, len); }
+
+BENCHMARK(BM_WriteStringIncreasing)->Range(0, 1024);
+BENCHMARK(BM_ReadStringIncreasing)->Range(0, 1024);
+
+} // namespace strings
+} // namespace tensorflow
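
The large table above pins down exact byte encodings, but the property it really exercises is that WriteSignedNumIncreasing is order-preserving: encoded strings compare lexicographically in the same order as the numbers they encode, with the byte length growing with the magnitude. A minimal sketch of checking that property directly, assuming the same OrderedCode API used in the test above (the include path is an assumption based on this change's layout; the helper name is illustrative):

    #include <cassert>
    #include <string>
    #include <vector>

    #include "tensorflow/core/lib/strings/ordered_code.h"  // assumed path

    // Sketch: adjacent values around sign and length boundaries must encode to
    // lexicographically increasing byte strings.
    void CheckSignedEncodingIsMonotonic() {
      using tensorflow::strings::OrderedCode;
      const std::vector<long long> nums = {-65, -64, -1, 0, 1, 63, 64, 65, 1 << 20};
      std::string prev;
      for (size_t i = 0; i < nums.size(); ++i) {
        std::string cur;
        OrderedCode::WriteSignedNumIncreasing(&cur, nums[i]);
        if (i > 0) assert(prev < cur);  // byte order matches numeric order
        prev = cur;
      }
    }
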
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
new file mode 100644
index 0000000000..cccd50c7ff
--- /dev/null
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -0,0 +1,312 @@
+#include "tensorflow/core/lib/strings/str_util.h"
+#include <ctype.h>
+
+namespace tensorflow {
+namespace str_util {
+
+static char hex_char[] = "0123456789abcdef";
+
+string CEscape(const string& src) {
+ string dest;
+
+ for (unsigned char c : src) {
+ switch (c) {
+ case '\n':
+ dest.append("\\n");
+ break;
+ case '\r':
+ dest.append("\\r");
+ break;
+ case '\t':
+ dest.append("\\t");
+ break;
+ case '\"':
+ dest.append("\\\"");
+ break;
+ case '\'':
+ dest.append("\\'");
+ break;
+ case '\\':
+ dest.append("\\\\");
+ break;
+ default:
+        // Emit a fixed three-octal-digit escape (\NNN).  Unlike a
+        // variable-length \xNN escape, it cannot absorb a digit that
+        // happens to follow it in the source string.
+        if ((c >= 0x80) || !isprint(c)) {
+          dest.append("\\");
+          dest.push_back(hex_char[c / 64]);
+          dest.push_back(hex_char[(c % 64) / 8]);
+          dest.push_back(hex_char[c % 8]);
+        } else {
+          dest.push_back(c);
+        }
+        break;
+ }
+ }
+
+ return dest;
+}
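
The three table lookups in the default case are simply the octal digits of the byte: c / 64, (c % 64) / 8, and c % 8, and because the first eight entries of hex_char are '0' through '7', the hex table doubles as an octal one. A standalone check of that arithmetic (the helper name is made up for illustration):

    #include <cassert>
    #include <string>

    // Sketch: reproduce CEscape's octal-escape arithmetic for one byte.
    std::string OctalEscape(unsigned char c) {
      static const char digits[] = "01234567";
      std::string out = "\\";
      out.push_back(digits[c / 64]);        // high octal digit
      out.push_back(digits[(c % 64) / 8]);  // middle octal digit
      out.push_back(digits[c % 8]);         // low octal digit
      return out;
    }

    int main() {
      assert(OctalEscape(0320) == "\\320");  // 0xD0, as in the "\320hi\200" test case
      assert(OctalEscape(0200) == "\\200");  // 0x80
      return 0;
    }
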
+
+namespace { // Private helpers for CUnescape().
+
+inline bool is_octal_digit(unsigned char c) { return c >= '0' && c <= '7'; }
+
+inline bool ascii_isxdigit(unsigned char c) {
+ return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') ||
+ (c >= 'A' && c <= 'F');
+}
+
+inline int hex_digit_to_int(char c) {
+ int x = static_cast<unsigned char>(c);
+ if (x > '9') {
+ x += 9;
+ }
+ return x & 0xf;
+}
+
+bool CUnescapeInternal(StringPiece source, char* dest, int* dest_len,
+ string* error) {
+ char* d = dest;
+ const char* p = source.data();
+ const char* end = source.end();
+ const char* last_byte = end - 1;
+
+ // Small optimization for case where source = dest and there's no escaping
+ while (p == d && p < end && *p != '\\') p++, d++;
+
+ while (p < end) {
+ if (*p != '\\') {
+ *d++ = *p++;
+ } else {
+ if (++p > last_byte) { // skip past the '\\'
+ if (error) *error = "String cannot end with \\";
+ return false;
+ }
+ switch (*p) {
+ case 'a':
+ *d++ = '\a';
+ break;
+ case 'b':
+ *d++ = '\b';
+ break;
+ case 'f':
+ *d++ = '\f';
+ break;
+ case 'n':
+ *d++ = '\n';
+ break;
+ case 'r':
+ *d++ = '\r';
+ break;
+ case 't':
+ *d++ = '\t';
+ break;
+ case 'v':
+ *d++ = '\v';
+ break;
+ case '\\':
+ *d++ = '\\';
+ break;
+ case '?':
+ *d++ = '\?';
+ break; // \? Who knew?
+ case '\'':
+ *d++ = '\'';
+ break;
+ case '"':
+ *d++ = '\"';
+ break;
+ case '0':
+ case '1':
+ case '2':
+ case '3': // octal digit: 1 to 3 digits
+ case '4':
+ case '5':
+ case '6':
+ case '7': {
+ const char* octal_start = p;
+ unsigned int ch = *p - '0';
+ if (p < last_byte && is_octal_digit(p[1])) ch = ch * 8 + *++p - '0';
+ if (p < last_byte && is_octal_digit(p[1]))
+ ch = ch * 8 + *++p - '0'; // now points at last digit
+ if (ch > 0xff) {
+ if (error) {
+ *error = "Value of \\" +
+ string(octal_start, p + 1 - octal_start) +
+ " exceeds 0xff";
+ }
+ return false;
+ }
+ *d++ = ch;
+ break;
+ }
+ case 'x':
+ case 'X': {
+ if (p >= last_byte) {
+ if (error) *error = "String cannot end with \\x";
+ return false;
+ } else if (!ascii_isxdigit(p[1])) {
+ if (error) *error = "\\x cannot be followed by a non-hex digit";
+ return false;
+ }
+ unsigned int ch = 0;
+ const char* hex_start = p;
+ while (p < last_byte && ascii_isxdigit(p[1]))
+ // Arbitrarily many hex digits
+ ch = (ch << 4) + hex_digit_to_int(*++p);
+ if (ch > 0xFF) {
+ if (error) {
+ *error = "Value of \\" + string(hex_start, p + 1 - hex_start) +
+ " exceeds 0xff";
+ }
+ return false;
+ }
+ *d++ = ch;
+ break;
+ }
+ default: {
+ if (error) *error = string("Unknown escape sequence: \\") + *p;
+ return false;
+ }
+ }
+ p++; // read past letter we escaped
+ }
+ }
+ *dest_len = d - dest;
+ return true;
+}
+
+} // namespace
+
+bool CUnescape(StringPiece source, string* dest, string* error) {
+ dest->resize(source.size());
+ int dest_size;
+ if (!CUnescapeInternal(source, const_cast<char*>(dest->data()), &dest_size,
+ error)) {
+ return false;
+ }
+ dest->erase(dest_size);
+ return true;
+}
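
CEscape and CUnescape are intended to round-trip arbitrary bytes through a printable form: escape once, store or log the result, and unescape to recover exactly the original string. A minimal usage sketch against the functions above, assuming tensorflow::string is std::string as in the default build:

    #include <cassert>
    #include <string>

    #include "tensorflow/core/lib/strings/str_util.h"

    int main() {
      const std::string original = "line1\nline2\t\xD0\x80";
      const std::string printable = tensorflow::str_util::CEscape(original);
      // 'printable' now holds only printable ASCII, e.g. "line1\\nline2\\t\\320\\200".
      std::string decoded, error;
      assert(tensorflow::str_util::CUnescape(printable, &decoded, &error));
      assert(decoded == original);  // byte-for-byte round trip
      return 0;
    }
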
+
+bool NumericParse32(const string& text, int32* val) {
+ // Slow, but this code is not performance critical, and this
+ // doesn't bring in any new dependencies
+ char junk;
+ if (sscanf(text.c_str(), "%d%c", val, &junk) == 1) {
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void StripTrailingWhitespace(string* s) {
+ string::size_type i;
+ for (i = s->size(); i > 0 && isspace((*s)[i - 1]); --i) {
+ }
+ s->resize(i);
+}
+
+// Return lower-cased version of s.
+string Lowercase(StringPiece s) {
+ string result(s.data(), s.size());
+ for (char& c : result) {
+ c = tolower(c);
+ }
+ return result;
+}
+
+// Return upper-cased version of s.
+string Uppercase(StringPiece s) {
+ string result(s.data(), s.size());
+ for (char& c : result) {
+ c = toupper(c);
+ }
+ return result;
+}
+
+void TitlecaseString(string* s, StringPiece delimiters) {
+ bool upper = true;
+ for (string::iterator ss = s->begin(); ss != s->end(); ++ss) {
+ if (upper) {
+ *ss = toupper(*ss);
+ }
+ upper = (delimiters.find(*ss) != StringPiece::npos);
+ }
+}
+
+size_t RemoveLeadingWhitespace(StringPiece* text) {
+ size_t count = 0;
+ const char* ptr = text->data();
+ while (count < text->size() && isspace(*ptr)) {
+ count++;
+ ptr++;
+ }
+ text->remove_prefix(count);
+ return count;
+}
+
+size_t RemoveTrailingWhitespace(StringPiece* text) {
+ size_t count = 0;
+ const char* ptr = text->data() + text->size() - 1;
+ while (count < text->size() && isspace(*ptr)) {
+ ++count;
+ --ptr;
+ }
+ text->remove_suffix(count);
+ return count;
+}
+
+size_t RemoveWhitespaceContext(StringPiece* text) {
+ // use RemoveLeadingWhitespace() and RemoveTrailingWhitespace() to do the job
+ return (RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text));
+}
+
+bool ConsumePrefix(StringPiece* s, StringPiece expected) {
+ if (s->starts_with(expected)) {
+ s->remove_prefix(expected.size());
+ return true;
+ }
+ return false;
+}
+
+bool ConsumeLeadingDigits(StringPiece* s, uint64* val) {
+ const char* p = s->data();
+ const char* limit = p + s->size();
+ uint64 v = 0;
+ while (p < limit) {
+ const char c = *p;
+ if (c < '0' || c > '9') break;
+    uint64 new_v = (v * 10) + (c - '0');
+    if (new_v / 8 < v) {
+      // Overflow occurred: a wrapped value is always less than 8 * v,
+      // while a non-wrapped value is at least 10 * v.
+      return false;
+    }
+ v = new_v;
+ p++;
+ }
+ if (p > s->data()) {
+ // Consume some digits
+ s->remove_prefix(p - s->data());
+ *val = v;
+ return true;
+ } else {
+ return false;
+ }
+}
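
The division in the overflow test is what makes it airtight: without wraparound, new_v is at least 10 * v, so new_v / 8 >= v always holds; after a 64-bit wraparound, the new value falls strictly below 8 * v, so the check fires even when the wrapped value happens to exceed v. A small self-contained illustration (the helper name and the specific constants are chosen only for this example):

    #include <cassert>
    #include <cstdint>

    // Sketch: the same division-based test as above.  Returns true if appending
    // 'digit' to the value in 'v' would overflow uint64.
    bool WouldOverflow(uint64_t v, int digit) {
      uint64_t new_v = v * 10 + digit;  // may wrap modulo 2^64
      return new_v / 8 < v;
    }

    int main() {
      // 1844674407370955161 * 10 + 5 == 2^64 - 1: still representable.
      assert(!WouldOverflow(1844674407370955161ull, 5));
      // One more than 2^64 - 1 wraps to 0 and is rejected.
      assert(WouldOverflow(1844674407370955161ull, 6));
      // Here the wrapped value (2049638230412172404) is larger than v, so a
      // plain "new_v < v" comparison alone would not notice the overflow.
      assert(WouldOverflow(2049638230412172402ull, 0));
      return 0;
    }
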
+
+bool SplitAndParseAsInts(StringPiece text, char delim,
+ std::vector<int32>* result) {
+ result->clear();
+ std::vector<string> num_strings = Split(text, delim);
+ for (const auto& s : num_strings) {
+ int32 num;
+ if (!NumericParse32(s, &num)) return false;
+ result->push_back(num);
+ }
+ return true;
+}
+
+} // namespace str_util
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
new file mode 100644
index 0000000000..34ea462b2d
--- /dev/null
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -0,0 +1,149 @@
+#ifndef TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+#define TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+
+#include <string>
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+// Basic string utility routines
+namespace tensorflow {
+namespace str_util {
+
+// Returns a version of 'src' where unprintable characters have been
+// escaped using C-style escape sequences.
+string CEscape(const string& src);
+
+// Copies "source" to "dest", rewriting C-style escape sequences --
+// '\n', '\r', '\\', '\ooo', etc -- to their ASCII equivalents.
+//
+// Errors: Sets the description of the first encountered error in
+// 'error'. To disable error reporting, set 'error' to NULL.
+//
+// NOTE: Does not support \u or \U!
+bool CUnescape(StringPiece source, string* dest, string* error);
+
+// If "text" can be successfully parsed as the ASCII representation of
+// an integer, sets "*val" to the value and returns true. Otherwise,
+// returns false.
+bool NumericParse32(const string& text, int32* val);
+
+// Removes any trailing whitespace from "*s".
+void StripTrailingWhitespace(string* s);
+
+// Removes leading ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveLeadingWhitespace(StringPiece* text);
+
+// Removes trailing ascii_isspace() characters.
+// Returns number of characters removed.
+size_t RemoveTrailingWhitespace(StringPiece* text);
+
+// Removes leading and trailing ascii_isspace() chars.
+// Returns number of chars removed.
+size_t RemoveWhitespaceContext(StringPiece* text);
+
+// Consume a leading positive integer value.  If any digits were
+// found, store the value of the leading unsigned number in "*val",
+// advance "*s" past the consumed number, and return true.  Returns
+// false, leaving "*s" unchanged, if no digits are present or if the
+// value would overflow a uint64.
+bool ConsumeLeadingDigits(StringPiece* s, uint64* val);
+
+// If "*s" starts with "expected", consume it and return true.
+// Otherwise, return false.
+bool ConsumePrefix(StringPiece* s, StringPiece expected);
+
+// Return lower-cased version of s.
+string Lowercase(StringPiece s);
+
+// Return upper-cased version of s.
+string Uppercase(StringPiece s);
+
+// Capitalize first character of each word in "*s". "delimiters" is a
+// set of characters that can be used as word boundaries.
+void TitlecaseString(string* s, StringPiece delimiters);
+
+// Join functionality
+template <typename T>
+string Join(const std::vector<T>& s, const char* sep);
+template <typename T>
+string Join(const gtl::ArraySlice<T>& s, const char* sep);
+
+struct AllowEmpty {
+ bool operator()(StringPiece sp) const { return true; }
+};
+struct SkipEmpty {
+ bool operator()(StringPiece sp) const { return !sp.empty(); }
+};
+struct SkipWhitespace {
+ bool operator()(StringPiece sp) const {
+ RemoveTrailingWhitespace(&sp);
+ return !sp.empty();
+ }
+};
+
+std::vector<string> Split(StringPiece text, char delim);
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delim, Predicate p);
+
+// Split "text" at "delim" characters, and parse each component as
+// an integer. If successful, adds the individual numbers in order
+// to "*result" and returns true. Otherwise returns false.
+bool SplitAndParseAsInts(StringPiece text, char delim,
+ std::vector<int32>* result);
+
+// ------------------------------------------------------------------
+// Implementation details below
+namespace internal {
+template <typename T>
+string JoinHelper(typename gtl::ArraySlice<T>::const_iterator begin,
+ typename gtl::ArraySlice<T>::const_iterator end,
+ const char* sep) {
+ string result;
+ bool first = true;
+ for (typename gtl::ArraySlice<T>::const_iterator it = begin; it != end;
+ ++it) {
+ tensorflow::strings::StrAppend(&result, (first ? "" : sep), *it);
+ first = false;
+ }
+ return result;
+}
+} // namespace internal
+
+template <typename T>
+string Join(const std::vector<T>& s, const char* sep) {
+ return Join<T>(gtl::ArraySlice<T>(s), sep);
+}
+
+template <typename T>
+string Join(const gtl::ArraySlice<T>& s, const char* sep) {
+ return internal::JoinHelper<T>(s.begin(), s.end(), sep);
+}
+
+inline std::vector<string> Split(StringPiece text, char delim) {
+ return Split(text, delim, AllowEmpty());
+}
+
+template <typename Predicate>
+std::vector<string> Split(StringPiece text, char delim, Predicate p) {
+ std::vector<string> result;
+ int token_start = 0;
+ if (!text.empty()) {
+ for (int i = 0; i < text.size() + 1; i++) {
+ if ((i == text.size()) || (text[i] == delim)) {
+ StringPiece token(text.data() + token_start, i - token_start);
+ if (p(token)) {
+ result.push_back(token.ToString());
+ }
+ token_start = i + 1;
+ }
+ }
+ }
+ return result;
+}
+
+} // namespace str_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
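
The predicate form of Split is the extension point worth remembering: AllowEmpty, SkipEmpty, and SkipWhitespace are just function objects applied to each token, so callers can supply their own policy. A hedged sketch of a user-defined predicate (SkipComments is not part of the library; it is made up for this example):

    #include <vector>

    #include "tensorflow/core/lib/strings/str_util.h"

    // Keeps tokens that are non-empty and do not start with '#'.
    struct SkipComments {
      bool operator()(tensorflow::StringPiece sp) const {
        return !sp.empty() && sp[0] != '#';
      }
    };

    int main() {
      std::vector<tensorflow::string> fields = tensorflow::str_util::Split(
          "alpha;#note;beta;;gamma", ';', SkipComments());
      // fields == {"alpha", "beta", "gamma"}
      return 0;
    }
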
diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc
new file mode 100644
index 0000000000..f71cc6c609
--- /dev/null
+++ b/tensorflow/core/lib/strings/str_util_test.cc
@@ -0,0 +1,258 @@
+#include "tensorflow/core/lib/strings/str_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(CEscape, Basic) {
+ EXPECT_EQ(str_util::CEscape("hello"), "hello");
+ EXPECT_EQ(str_util::CEscape("hello\n"), "hello\\n");
+ EXPECT_EQ(str_util::CEscape("hello\r"), "hello\\r");
+ EXPECT_EQ(str_util::CEscape("\t\r\"'"), "\\t\\r\\\"\\'");
+ EXPECT_EQ(str_util::CEscape("\320hi\200"), "\\320hi\\200");
+}
+
+string ExpectCUnescapeSuccess(StringPiece source) {
+ string dest;
+ string error;
+ EXPECT_TRUE(str_util::CUnescape(source, &dest, &error)) << error;
+ return dest;
+}
+
+TEST(CUnescape, Basic) {
+ EXPECT_EQ("hello", ExpectCUnescapeSuccess("hello"));
+ EXPECT_EQ("hello\n", ExpectCUnescapeSuccess("hello\\n"));
+ EXPECT_EQ("hello\r", ExpectCUnescapeSuccess("hello\\r"));
+ EXPECT_EQ("\t\r\"'", ExpectCUnescapeSuccess("\\t\\r\\\"\\'"));
+ EXPECT_EQ("\320hi\200", ExpectCUnescapeSuccess("\\320hi\\200"));
+}
+
+TEST(NumericParse32, Basic) {
+ int32 val = -1234;
+ EXPECT_TRUE(str_util::NumericParse32("0", &val) && val == 0);
+ EXPECT_TRUE(str_util::NumericParse32("123", &val) && val == 123);
+ EXPECT_TRUE(str_util::NumericParse32("-375", &val) && val == -375);
+ EXPECT_FALSE(str_util::NumericParse32("123hello", &val));
+ EXPECT_FALSE(str_util::NumericParse32("hello123", &val));
+}
+
+TEST(StripTrailingWhitespace, Basic) {
+ string test;
+ test = "hello";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "hello");
+
+ test = "foo ";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "foo");
+
+ test = " ";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "");
+
+ test = "";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, "");
+
+ test = " abc\t";
+ str_util::StripTrailingWhitespace(&test);
+ EXPECT_EQ(test, " abc");
+}
+
+TEST(RemoveLeadingWhitespace, Basic) {
+ string text = " \t \n \r Quick\t";
+ StringPiece data(text);
+ // check that all whitespace is removed
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 11);
+ EXPECT_EQ(data, StringPiece("Quick\t"));
+ // check that non-whitespace is not removed
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece("Quick\t"));
+}
+
+TEST(RemoveLeadingWhitespace, TerminationHandling) {
+ // check termination handling
+ string text = "\t";
+ StringPiece data(text);
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 1);
+ EXPECT_EQ(data, StringPiece(""));
+
+ // check termination handling again
+ EXPECT_EQ(str_util::RemoveLeadingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece(""));
+}
+
+TEST(RemoveTrailingWhitespace, Basic) {
+ string text = " \t \n \r Quick \t";
+ StringPiece data(text);
+ // check that all whitespace is removed
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 2);
+ EXPECT_EQ(data, StringPiece(" \t \n \r Quick"));
+ // check that non-whitespace is not removed
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece(" \t \n \r Quick"));
+}
+
+TEST(RemoveTrailingWhitespace, TerminationHandling) {
+ // check termination handling
+ string text = "\t";
+ StringPiece data(text);
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 1);
+ EXPECT_EQ(data, StringPiece(""));
+
+ // check termination handling again
+ EXPECT_EQ(str_util::RemoveTrailingWhitespace(&data), 0);
+ EXPECT_EQ(data, StringPiece(""));
+}
+
+TEST(RemoveWhitespaceContext, Basic) {
+ string text = " \t \n \r Quick \t";
+ StringPiece data(text);
+ // check that all whitespace is removed
+ EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 13);
+ EXPECT_EQ(data, StringPiece("Quick"));
+ // check that non-whitespace is not removed
+ EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0);
+ EXPECT_EQ(data, StringPiece("Quick"));
+
+ // Test empty string
+ text = "";
+ data = text;
+ EXPECT_EQ(str_util::RemoveWhitespaceContext(&data), 0);
+ EXPECT_EQ(data, StringPiece(""));
+}
+
+void TestConsumeLeadingDigits(StringPiece s, int64 expected,
+ StringPiece remaining) {
+ uint64 v;
+ StringPiece input(s);
+ if (str_util::ConsumeLeadingDigits(&input, &v)) {
+ EXPECT_EQ(v, static_cast<uint64>(expected));
+ EXPECT_EQ(input, remaining);
+ } else {
+ EXPECT_LT(expected, 0);
+ EXPECT_EQ(input, remaining);
+ }
+}
+
+TEST(ConsumeLeadingDigits, Basic) {
+ TestConsumeLeadingDigits("123", 123, "");
+ TestConsumeLeadingDigits("a123", -1, "a123");
+ TestConsumeLeadingDigits("9_", 9, "_");
+ TestConsumeLeadingDigits("11111111111xyz", 11111111111ll, "xyz");
+
+ // Overflow case
+ TestConsumeLeadingDigits("1111111111111111111111111111111xyz", -1,
+ "1111111111111111111111111111111xyz");
+
+ // 2^64
+ TestConsumeLeadingDigits("18446744073709551616xyz", -1,
+ "18446744073709551616xyz");
+ // 2^64-1
+ TestConsumeLeadingDigits("18446744073709551615xyz", 18446744073709551615ull,
+ "xyz");
+}
+
+TEST(ConsumePrefix, Basic) {
+ string s("abcdef");
+ StringPiece input(s);
+ EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdefg"));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_FALSE(str_util::ConsumePrefix(&input, "abce"));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_TRUE(str_util::ConsumePrefix(&input, ""));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_FALSE(str_util::ConsumePrefix(&input, "abcdeg"));
+ EXPECT_EQ(input, "abcdef");
+
+ EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcdef"));
+ EXPECT_EQ(input, "");
+
+ input = s;
+ EXPECT_TRUE(str_util::ConsumePrefix(&input, "abcde"));
+ EXPECT_EQ(input, "f");
+}
+
+TEST(JoinStrings, Basic) {
+ std::vector<string> s;
+ s = {"hi"};
+ EXPECT_EQ(str_util::Join(s, " "), "hi");
+ s = {"hi", "there", "strings"};
+ EXPECT_EQ(str_util::Join(s, " "), "hi there strings");
+
+ std::vector<StringPiece> sp;
+ sp = {"hi"};
+ EXPECT_EQ(str_util::Join(sp, ",,"), "hi");
+ sp = {"hi", "there", "strings"};
+ EXPECT_EQ(str_util::Join(sp, "--"), "hi--there--strings");
+}
+
+TEST(Split, Basic) {
+ EXPECT_TRUE(str_util::Split("", ',').empty());
+ EXPECT_EQ(str_util::Join(str_util::Split("a", ','), "|"), "a");
+ EXPECT_EQ(str_util::Join(str_util::Split(",", ','), "|"), "|");
+ EXPECT_EQ(str_util::Join(str_util::Split("a,b,c", ','), "|"), "a|b|c");
+ EXPECT_EQ(str_util::Join(str_util::Split("a,,,b,,c,", ','), "|"),
+ "a|||b||c|");
+ EXPECT_EQ(str_util::Join(
+ str_util::Split("a,,,b,,c,", ',', str_util::SkipEmpty()), "|"),
+ "a|b|c");
+ EXPECT_EQ(
+ str_util::Join(
+ str_util::Split("a, ,b,,c,", ',', str_util::SkipWhitespace()), "|"),
+ "a|b|c");
+}
+
+TEST(SplitAndParseAsInts, Basic) {
+ std::vector<int32> nums;
+ EXPECT_TRUE(str_util::SplitAndParseAsInts("", ',', &nums));
+ EXPECT_EQ(nums.size(), 0);
+
+ EXPECT_TRUE(str_util::SplitAndParseAsInts("134", ',', &nums));
+ EXPECT_EQ(nums.size(), 1);
+ EXPECT_EQ(nums[0], 134);
+
+ EXPECT_TRUE(str_util::SplitAndParseAsInts("134,2,13,-5", ',', &nums));
+ EXPECT_EQ(nums.size(), 4);
+ EXPECT_EQ(nums[0], 134);
+ EXPECT_EQ(nums[1], 2);
+ EXPECT_EQ(nums[2], 13);
+ EXPECT_EQ(nums[3], -5);
+
+ EXPECT_FALSE(str_util::SplitAndParseAsInts("abc", ',', &nums));
+
+ EXPECT_FALSE(str_util::SplitAndParseAsInts("-13,abc", ',', &nums));
+
+ EXPECT_FALSE(str_util::SplitAndParseAsInts("13,abc,5", ',', &nums));
+}
+
+TEST(Lowercase, Basic) {
+ EXPECT_EQ("", str_util::Lowercase(""));
+ EXPECT_EQ("hello", str_util::Lowercase("hello"));
+ EXPECT_EQ("hello world", str_util::Lowercase("Hello World"));
+}
+
+TEST(Uppercase, Basic) {
+ EXPECT_EQ("", str_util::Uppercase(""));
+ EXPECT_EQ("HELLO", str_util::Uppercase("hello"));
+ EXPECT_EQ("HELLO WORLD", str_util::Uppercase("Hello World"));
+}
+
+TEST(TitlecaseString, Basic) {
+ string s = "sparse_lookup";
+ str_util::TitlecaseString(&s, "_");
+ ASSERT_EQ(s, "Sparse_Lookup");
+
+ s = "sparse_lookup";
+ str_util::TitlecaseString(&s, " ");
+ ASSERT_EQ(s, "Sparse_lookup");
+
+ s = "dense";
+ str_util::TitlecaseString(&s, " ");
+ ASSERT_EQ(s, "Dense");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc
new file mode 100644
index 0000000000..e564b9eb73
--- /dev/null
+++ b/tensorflow/core/lib/strings/strcat.cc
@@ -0,0 +1,194 @@
+#include "tensorflow/core/lib/strings/strcat.h"
+
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+
+namespace tensorflow {
+namespace strings {
+
+AlphaNum gEmptyAlphaNum("");
+
+AlphaNum::AlphaNum(Hex hex) {
+ char *const end = &digits_[kFastToBufferSize];
+ char *writer = end;
+ uint64 value = hex.value;
+ uint64 width = hex.spec;
+  // We accomplish minimum width by OR'ing in (1 << ((width - 1) * 4)),
+  // the smallest hex number that is as wide as the caller asked for,
+  // so the digit loop below runs at least 'width' times.
+ uint64 mask = ((static_cast<uint64>(1) << (width - 1) * 4)) | value;
+ static const char hexdigits[] = "0123456789abcdef";
+ do {
+ *--writer = hexdigits[value & 0xF];
+ value >>= 4;
+ mask >>= 4;
+ } while (mask != 0);
+ piece_.set(writer, end - writer);
+}
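
A worked example of the mask trick: for Hex(0x1f, ZERO_PAD_4), width is 4, so mask = (1 << 12) | 0x1f = 0x101f; the do/while loop therefore runs four times and writes "001f", while wider values simply keep looping until their own nibbles are exhausted. The digit loop can be reproduced standalone (the helper name is illustrative):

    #include <cassert>
    #include <cstdint>
    #include <string>

    // Sketch: hex digits of 'value', zero-padded to at least 'width' digits,
    // using the same mask-driven loop as AlphaNum(Hex).
    std::string HexWithMinWidth(uint64_t value, int width) {
      static const char hexdigits[] = "0123456789abcdef";
      // OR'ing in a 1 bit at nibble (width - 1) forces at least 'width' iterations.
      uint64_t mask = (static_cast<uint64_t>(1) << ((width - 1) * 4)) | value;
      std::string out;
      do {
        out.insert(out.begin(), hexdigits[value & 0xF]);
        value >>= 4;
        mask >>= 4;
      } while (mask != 0);
      return out;
    }

    int main() {
      assert(HexWithMinWidth(0x1f, 4) == "001f");      // padded up to the minimum width
      assert(HexWithMinWidth(0x12345, 4) == "12345");  // wider values are never truncated
      return 0;
    }
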
+
+// ----------------------------------------------------------------------
+// StrCat()
+// This merges the given strings or integers, with no delimiter. This
+// is designed to be the fastest possible way to construct a string out
+// of a mix of raw C strings, StringPieces, strings, and integer values.
+// ----------------------------------------------------------------------
+
+// Append is merely a version of memcpy that returns the address of the byte
+// after the area just overwritten. It comes in multiple flavors to minimize
+// call overhead.
+static char *Append1(char *out, const AlphaNum &x) {
+ memcpy(out, x.data(), x.size());
+ return out + x.size();
+}
+
+static char *Append2(char *out, const AlphaNum &x1, const AlphaNum &x2) {
+ memcpy(out, x1.data(), x1.size());
+ out += x1.size();
+
+ memcpy(out, x2.data(), x2.size());
+ return out + x2.size();
+}
+
+static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2,
+ const AlphaNum &x3, const AlphaNum &x4) {
+ memcpy(out, x1.data(), x1.size());
+ out += x1.size();
+
+ memcpy(out, x2.data(), x2.size());
+ out += x2.size();
+
+ memcpy(out, x3.data(), x3.size());
+ out += x3.size();
+
+ memcpy(out, x4.data(), x4.size());
+ return out + x4.size();
+}
+
+string StrCat(const AlphaNum &a, const AlphaNum &b) {
+ string result;
+ gtl::STLStringResizeUninitialized(&result, a.size() + b.size());
+ char *const begin = &*result.begin();
+ char *out = Append2(begin, a, b);
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c) {
+ string result;
+ gtl::STLStringResizeUninitialized(&result, a.size() + b.size() + c.size());
+ char *const begin = &*result.begin();
+ char *out = Append2(begin, a, b);
+ out = Append1(out, c);
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d) {
+ string result;
+ gtl::STLStringResizeUninitialized(&result,
+ a.size() + b.size() + c.size() + d.size());
+ char *const begin = &*result.begin();
+ char *out = Append4(begin, a, b, c, d);
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+namespace internal {
+
+// Do not call directly - these are not part of the public API.
+string CatPieces(std::initializer_list<StringPiece> pieces) {
+ string result;
+ size_t total_size = 0;
+ for (const StringPiece piece : pieces) total_size += piece.size();
+ gtl::STLStringResizeUninitialized(&result, total_size);
+
+ char *const begin = &*result.begin();
+ char *out = begin;
+ for (const StringPiece piece : pieces) {
+ const size_t this_size = piece.size();
+ memcpy(out, piece.data(), this_size);
+ out += this_size;
+ }
+ DCHECK_EQ(out, begin + result.size());
+ return result;
+}
+
+// It's possible to call StrAppend with a StringPiece that is itself a fragment
+// of the string we're appending to.  The result of doing so is undefined, since
+// the string's buffer may be reallocated while it is still being read.
+// Therefore, check for this in debug mode.  Use unsigned math so we only have
+// to do one comparison.
+#define DCHECK_NO_OVERLAP(dest, src) \
+ DCHECK_GE(uintptr_t((src).data() - (dest).data()), uintptr_t((dest).size()))
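
The single comparison works because the subtraction is done in unsigned arithmetic: if src starts inside dest, the difference is a small value below dest.size(); if src starts anywhere before dest, the difference wraps to a huge value and the >= test still passes. A small illustration of the same trick (it shares the macro's assumption of a flat address space, and the helper name is made up):

    #include <cassert>
    #include <cstdint>
    #include <string>

    // True when 'p' points into [data, data + size), decided with one unsigned
    // comparison, exactly like DCHECK_NO_OVERLAP above.
    bool StartsInside(const char* data, size_t size, const char* p) {
      return uintptr_t(p - data) < uintptr_t(size);
    }

    int main() {
      std::string dest = "abcdef";
      assert(StartsInside(dest.data(), dest.size(), dest.data() + 2));  // aliases dest
      static const char other[] = "elsewhere";
      assert(!StartsInside(dest.data(), dest.size(), other));           // unrelated buffer
      return 0;
    }
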
+
+void AppendPieces(string *result, std::initializer_list<StringPiece> pieces) {
+ size_t old_size = result->size();
+ size_t total_size = old_size;
+ for (const StringPiece piece : pieces) {
+ DCHECK_NO_OVERLAP(*result, piece);
+ total_size += piece.size();
+ }
+ gtl::STLStringResizeUninitialized(result, total_size);
+
+ char *const begin = &*result->begin();
+ char *out = begin + old_size;
+ for (const StringPiece piece : pieces) {
+ const size_t this_size = piece.size();
+ memcpy(out, piece.data(), this_size);
+ out += this_size;
+ }
+ DCHECK_EQ(out, begin + result->size());
+}
+
+} // namespace internal
+
+void StrAppend(string *result, const AlphaNum &a) {
+ DCHECK_NO_OVERLAP(*result, a);
+ result->append(a.data(), a.size());
+}
+
+void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b) {
+ DCHECK_NO_OVERLAP(*result, a);
+ DCHECK_NO_OVERLAP(*result, b);
+ string::size_type old_size = result->size();
+ gtl::STLStringResizeUninitialized(result, old_size + a.size() + b.size());
+ char *const begin = &*result->begin();
+ char *out = Append2(begin + old_size, a, b);
+ DCHECK_EQ(out, begin + result->size());
+}
+
+void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c) {
+ DCHECK_NO_OVERLAP(*result, a);
+ DCHECK_NO_OVERLAP(*result, b);
+ DCHECK_NO_OVERLAP(*result, c);
+ string::size_type old_size = result->size();
+ gtl::STLStringResizeUninitialized(result,
+ old_size + a.size() + b.size() + c.size());
+ char *const begin = &*result->begin();
+ char *out = Append2(begin + old_size, a, b);
+ out = Append1(out, c);
+ DCHECK_EQ(out, begin + result->size());
+}
+
+void StrAppend(string *result, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c, const AlphaNum &d) {
+ DCHECK_NO_OVERLAP(*result, a);
+ DCHECK_NO_OVERLAP(*result, b);
+ DCHECK_NO_OVERLAP(*result, c);
+ DCHECK_NO_OVERLAP(*result, d);
+ string::size_type old_size = result->size();
+ gtl::STLStringResizeUninitialized(
+ result, old_size + a.size() + b.size() + c.size() + d.size());
+ char *const begin = &*result->begin();
+ char *out = Append4(begin + old_size, a, b, c, d);
+ DCHECK_EQ(out, begin + result->size());
+}
+
+} // namespace strings
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
new file mode 100644
index 0000000000..763ad8368a
--- /dev/null
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -0,0 +1,229 @@
+// #status: RECOMMENDED
+// #category: operations on strings
+// #summary: Merges strings or numbers with no delimiter.
+//
+#ifndef TENSORFLOW_LIB_STRINGS_STRCAT_H_
+#define TENSORFLOW_LIB_STRINGS_STRCAT_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/port.h"
+
+// The AlphaNum type was designed to be used as the parameter type for StrCat().
+// Any routine accepting either a string or a number may accept it.
+// The basic idea is that by accepting a "const AlphaNum &" as an argument
+// to your function, your callers will automagically convert bools, integers,
+// and floating point values to strings for you.
+//
+// NOTE: Use of AlphaNum outside of the //strings package is unsupported except
+// for the specific case of function parameters of type "AlphaNum" or "const
+// AlphaNum &". In particular, instantiating AlphaNum directly as a stack
+// variable is not supported.
+//
+// Conversion from 8-bit values is not accepted because if it were, then an
+// attempt to pass ':' instead of ":" might result in a 58 ending up in your
+// result.
+//
+// Bools convert to "0" or "1".
+//
+// Floating point values are converted to a string which, if passed to strtod(),
+// would produce the exact same original double (except in case of NaN; all NaNs
+// are considered the same value). We try to keep the string short but it's not
+// guaranteed to be as short as possible.
+//
+// You can convert to Hexadecimal output rather than Decimal output using Hex.
+// To do this, pass strings::Hex(my_int) as a parameter to StrCat. You may
+// specify a minimum field width using a separate parameter, so the equivalent
+// of Printf("%04x", my_int) is StrCat(Hex(my_int, strings::ZERO_PAD_4))
+//
+// This class has implicit constructors.
+namespace tensorflow {
+namespace strings {
+
+enum PadSpec {
+ NO_PAD = 1,
+ ZERO_PAD_2,
+ ZERO_PAD_3,
+ ZERO_PAD_4,
+ ZERO_PAD_5,
+ ZERO_PAD_6,
+ ZERO_PAD_7,
+ ZERO_PAD_8,
+ ZERO_PAD_9,
+ ZERO_PAD_10,
+ ZERO_PAD_11,
+ ZERO_PAD_12,
+ ZERO_PAD_13,
+ ZERO_PAD_14,
+ ZERO_PAD_15,
+ ZERO_PAD_16,
+};
+
+struct Hex {
+ uint64 value;
+ enum PadSpec spec;
+ template <class Int>
+ explicit Hex(Int v, PadSpec s = NO_PAD)
+ : spec(s) {
+ // Prevent sign-extension by casting integers to
+ // their unsigned counterparts.
+ static_assert(
+ sizeof(v) == 1 || sizeof(v) == 2 || sizeof(v) == 4 || sizeof(v) == 8,
+ "Unknown integer type");
+ value = sizeof(v) == 1
+ ? static_cast<uint8>(v)
+ : sizeof(v) == 2 ? static_cast<uint16>(v)
+ : sizeof(v) == 4 ? static_cast<uint32>(v)
+ : static_cast<uint64>(v);
+ }
+};
+
+class AlphaNum {
+ public:
+ // No bool ctor -- bools convert to an integral type.
+ // A bool ctor would also convert incoming pointers (bletch).
+
+ AlphaNum(int i32) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastInt32ToBufferLeft(i32, digits_) - &digits_[0]) {}
+ AlphaNum(unsigned int u32) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastUInt32ToBufferLeft(u32, digits_) - &digits_[0]) {}
+ AlphaNum(long x) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastInt64ToBufferLeft(x, digits_) - &digits_[0]) {}
+ AlphaNum(unsigned long x) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastUInt64ToBufferLeft(x, digits_) - &digits_[0]) {}
+ AlphaNum(long long int i64) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastInt64ToBufferLeft(i64, digits_) - &digits_[0]) {}
+ AlphaNum(unsigned long long int u64) // NOLINT(runtime/explicit)
+ : piece_(digits_, FastUInt64ToBufferLeft(u64, digits_) - &digits_[0]) {}
+
+ AlphaNum(float f) // NOLINT(runtime/explicit)
+ : piece_(digits_, strlen(FloatToBuffer(f, digits_))) {}
+ AlphaNum(double f) // NOLINT(runtime/explicit)
+ : piece_(digits_, strlen(DoubleToBuffer(f, digits_))) {}
+
+ AlphaNum(Hex hex); // NOLINT(runtime/explicit)
+
+ AlphaNum(const char *c_str) : piece_(c_str) {} // NOLINT(runtime/explicit)
+ AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
+ AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
+ : piece_(str) {}
+
+ StringPiece::size_type size() const { return piece_.size(); }
+ const char *data() const { return piece_.data(); }
+ StringPiece Piece() const { return piece_; }
+
+ private:
+ StringPiece piece_;
+ char digits_[kFastToBufferSize];
+
+ // Use ":" not ':'
+ AlphaNum(char c); // NOLINT(runtime/explicit)
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AlphaNum);
+};
+
+extern AlphaNum gEmptyAlphaNum;
+
+using strings::AlphaNum;
+using strings::gEmptyAlphaNum;
+
+// ----------------------------------------------------------------------
+// StrCat()
+// This merges the given strings or numbers, with no delimiter. This
+// is designed to be the fastest possible way to construct a string out
+// of a mix of raw C strings, StringPieces, strings, bool values,
+// and numeric values.
+//
+// Don't use this for user-visible strings. The localization process
+// works poorly on strings built up out of fragments.
+//
+// For clarity and performance, don't use StrCat when appending to a
+// string. In particular, avoid using any of these (anti-)patterns:
+// str.append(StrCat(...))
+// str += StrCat(...)
+// str = StrCat(str, ...)
+// where the last is the worst, with the potential to change a loop
+// from a linear time operation with O(1) dynamic allocations into a
+// quadratic time operation with O(n) dynamic allocations. StrAppend
+// is a better choice than any of the above, subject to the restriction
+// of StrAppend(&str, a, b, c, ...) that none of the a, b, c, ... may
+// be a reference into str.
+// ----------------------------------------------------------------------
+
+// For performance reasons, we have specializations for <= 4 args.
+string StrCat(const AlphaNum &a) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c) TF_MUST_USE_RESULT;
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d) TF_MUST_USE_RESULT;
+
+// inline definitions must be duplicated due to TF_MUST_USE_RESULT
+inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }
+
+namespace internal {
+
+// Do not call directly - this is not part of the public API.
+string CatPieces(std::initializer_list<StringPiece> pieces);
+void AppendPieces(string *dest, std::initializer_list<StringPiece> pieces);
+
+} // namespace internal
+
+// Support 5 or more arguments
+template <typename... AV>
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d, const AlphaNum &e,
+ const AV &... args) TF_MUST_USE_RESULT;
+
+template <typename... AV>
+inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d, const AlphaNum &e, const AV &... args) {
+ return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(),
+ e.Piece(),
+ static_cast<const AlphaNum &>(args).Piece()...});
+}
+
+// ----------------------------------------------------------------------
+// StrAppend()
+// Same as above, but adds the output to the given string.
+// WARNING: For speed, StrAppend does not try to check each of its input
+// arguments to be sure that they are not a subset of the string being
+// appended to. That is, while this will work:
+//
+// string s = "foo";
+// s += s;
+//
+// This will not (necessarily) work:
+//
+// string s = "foo";
+// StrAppend(&s, s);
+//
+// Note: historically StrCat was limited to 26 arguments and StrAppend to 9;
+// the variadic overloads below remove both limits.  If you prefer shorter
+// argument lists, consecutive calls to StrAppend are quite efficient.
+// ----------------------------------------------------------------------
+
+void StrAppend(string *dest, const AlphaNum &a);
+void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b);
+void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c);
+void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c, const AlphaNum &d);
+
+// Support 5 or more arguments
+template <typename... AV>
+inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
+ const AlphaNum &c, const AlphaNum &d, const AlphaNum &e,
+ const AV &... args) {
+ internal::AppendPieces(dest,
+ {a.Piece(), b.Piece(), c.Piece(), d.Piece(), e.Piece(),
+ static_cast<const AlphaNum &>(args).Piece()...});
+}
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_STRCAT_H_
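
The append-in-a-loop guidance above is worth seeing spelled out: rebuilding the accumulator with StrCat copies everything written so far on every iteration, while StrAppend extends it in place. A sketch of the two shapes, assuming tensorflow::string aliases std::string as in the default build (the joining task itself is only an example):

    #include <string>
    #include <vector>

    #include "tensorflow/core/lib/strings/strcat.h"

    using tensorflow::strings::StrAppend;
    using tensorflow::strings::StrCat;

    // Anti-pattern: every iteration copies the whole accumulator -> O(n^2) bytes moved.
    std::string JoinQuadratic(const std::vector<std::string>& parts) {
      std::string out;
      for (const auto& p : parts) out = StrCat(out, p, ",");
      return out;
    }

    // Preferred: StrAppend grows the string in place -> amortized linear.
    std::string JoinLinear(const std::vector<std::string>& parts) {
      std::string out;
      for (const auto& p : parts) StrAppend(&out, p, ",");
      return out;
    }
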
diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc
new file mode 100644
index 0000000000..9ff7d81af9
--- /dev/null
+++ b/tensorflow/core/lib/strings/strcat_test.cc
@@ -0,0 +1,324 @@
+#include "tensorflow/core/lib/strings/strcat.h"
+
+#include <string>
+
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/port.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace strings {
+
+// Test StrCat of ints and longs of various sizes and signedness.
+TEST(StrCat, Ints) {
+ const int16 s = -1;
+ const uint16 us = 2;
+ const int i = -3;
+ const unsigned int ui = 4;
+ const int32 l = -5;
+ const uint32 ul = 6;
+ const int64 ll = -7;
+ const uint64 ull = 8;
+ const ptrdiff_t ptrdiff = -9;
+ const size_t size = 10;
+ const ssize_t ssize = -11;
+ const intptr_t intptr = -12;
+ const uintptr_t uintptr = 13;
+ string answer;
+ answer = StrCat(s, us);
+ EXPECT_EQ(answer, "-12");
+ answer = StrCat(i, ui);
+ EXPECT_EQ(answer, "-34");
+ answer = StrCat(l, ul);
+ EXPECT_EQ(answer, "-56");
+ answer = StrCat(ll, ull);
+ EXPECT_EQ(answer, "-78");
+ answer = StrCat(ptrdiff, size);
+ EXPECT_EQ(answer, "-910");
+ answer = StrCat(ssize, intptr);
+ EXPECT_EQ(answer, "-11-12");
+ answer = StrCat(uintptr, 0);
+ EXPECT_EQ(answer, "130");
+}
+
+TEST(StrCat, Basics) {
+ string result;
+
+ string strs[] = {"Hello", "Cruel", "World"};
+
+ StringPiece pieces[] = {"Hello", "Cruel", "World"};
+
+ const char *c_strs[] = {"Hello", "Cruel", "World"};
+
+ int32 i32s[] = {'H', 'C', 'W'};
+ uint64 ui64s[] = {12345678910LL, 10987654321LL};
+
+ result = StrCat(false, true, 2, 3);
+ EXPECT_EQ(result, "0123");
+
+ result = StrCat(-1);
+ EXPECT_EQ(result, "-1");
+
+ result = StrCat(0.5);
+ EXPECT_EQ(result, "0.5");
+
+ result = StrCat(strs[1], pieces[2]);
+ EXPECT_EQ(result, "CruelWorld");
+
+ result = StrCat(strs[0], ", ", pieces[2]);
+ EXPECT_EQ(result, "Hello, World");
+
+ result = StrCat(strs[0], ", ", strs[1], " ", strs[2], "!");
+ EXPECT_EQ(result, "Hello, Cruel World!");
+
+ result = StrCat(pieces[0], ", ", pieces[1], " ", pieces[2]);
+ EXPECT_EQ(result, "Hello, Cruel World");
+
+ result = StrCat(c_strs[0], ", ", c_strs[1], " ", c_strs[2]);
+ EXPECT_EQ(result, "Hello, Cruel World");
+
+ result = StrCat("ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!");
+ EXPECT_EQ(result, "ASCII 72, 67 87!");
+
+ result = StrCat(ui64s[0], ", ", ui64s[1], "!");
+ EXPECT_EQ(result, "12345678910, 10987654321!");
+
+ string one = "1"; // Actually, it's the size of this string that we want; a
+ // 64-bit build distinguishes between size_t and uint64,
+ // even though they're both unsigned 64-bit values.
+ result = StrCat("And a ", one.size(), " and a ", &result[2] - &result[0],
+ " and a ", one, " 2 3 4", "!");
+ EXPECT_EQ(result, "And a 1 and a 2 and a 1 2 3 4!");
+
+ // result = StrCat("Single chars won't compile", '!');
+ // result = StrCat("Neither will NULLs", NULL);
+ result = StrCat("To output a char by ASCII/numeric value, use +: ", '!' + 0);
+ EXPECT_EQ(result, "To output a char by ASCII/numeric value, use +: 33");
+
+ float f = 100000.5;
+ result = StrCat("A hundred K and a half is ", f);
+ EXPECT_EQ(result, "A hundred K and a half is 100000.5");
+
+ double d = f;
+ d *= d;
+ result = StrCat("A hundred K and a half squared is ", d);
+ EXPECT_EQ(result, "A hundred K and a half squared is 10000100000.25");
+
+ result = StrCat(1, 2, 333, 4444, 55555, 666666, 7777777, 88888888, 999999999);
+ EXPECT_EQ(result, "12333444455555666666777777788888888999999999");
+}
+
+TEST(StrCat, MaxArgs) {
+ string result;
+ // Test 10 up to 26 arguments, the current maximum
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a");
+ EXPECT_EQ(result, "123456789a");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b");
+ EXPECT_EQ(result, "123456789ab");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c");
+ EXPECT_EQ(result, "123456789abc");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d");
+ EXPECT_EQ(result, "123456789abcd");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e");
+ EXPECT_EQ(result, "123456789abcde");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f");
+ EXPECT_EQ(result, "123456789abcdef");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g");
+ EXPECT_EQ(result, "123456789abcdefg");
+ result =
+ StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g", "h");
+ EXPECT_EQ(result, "123456789abcdefgh");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i");
+ EXPECT_EQ(result, "123456789abcdefghi");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j");
+ EXPECT_EQ(result, "123456789abcdefghij");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k");
+ EXPECT_EQ(result, "123456789abcdefghijk");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l");
+ EXPECT_EQ(result, "123456789abcdefghijkl");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m");
+ EXPECT_EQ(result, "123456789abcdefghijklm");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n");
+ EXPECT_EQ(result, "123456789abcdefghijklmn");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n", "o");
+ EXPECT_EQ(result, "123456789abcdefghijklmno");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n", "o", "p");
+ EXPECT_EQ(result, "123456789abcdefghijklmnop");
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, "a", "b", "c", "d", "e", "f", "g",
+ "h", "i", "j", "k", "l", "m", "n", "o", "p", "q");
+ EXPECT_EQ(result, "123456789abcdefghijklmnopq");
+ // No limit thanks to C++11's variadic templates
+ result = StrCat(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e", "f",
+ "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r",
+ "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D",
+ "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P",
+ "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z");
+ EXPECT_EQ(result,
+ "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ");
+}
+
+TEST(StrAppend, Basics) {
+ string result = "existing text";
+
+ string strs[] = {"Hello", "Cruel", "World"};
+
+ StringPiece pieces[] = {"Hello", "Cruel", "World"};
+
+ const char *c_strs[] = {"Hello", "Cruel", "World"};
+
+ int32 i32s[] = {'H', 'C', 'W'};
+ uint64 ui64s[] = {12345678910LL, 10987654321LL};
+
+ string::size_type old_size = result.size();
+ StrAppend(&result, strs[0]);
+ EXPECT_EQ(result.substr(old_size), "Hello");
+
+ old_size = result.size();
+ StrAppend(&result, strs[1], pieces[2]);
+ EXPECT_EQ(result.substr(old_size), "CruelWorld");
+
+ old_size = result.size();
+ StrAppend(&result, strs[0], ", ", pieces[2]);
+ EXPECT_EQ(result.substr(old_size), "Hello, World");
+
+ old_size = result.size();
+ StrAppend(&result, strs[0], ", ", strs[1], " ", strs[2], "!");
+ EXPECT_EQ(result.substr(old_size), "Hello, Cruel World!");
+
+ old_size = result.size();
+ StrAppend(&result, pieces[0], ", ", pieces[1], " ", pieces[2]);
+ EXPECT_EQ(result.substr(old_size), "Hello, Cruel World");
+
+ old_size = result.size();
+ StrAppend(&result, c_strs[0], ", ", c_strs[1], " ", c_strs[2]);
+ EXPECT_EQ(result.substr(old_size), "Hello, Cruel World");
+
+ old_size = result.size();
+ StrAppend(&result, "ASCII ", i32s[0], ", ", i32s[1], " ", i32s[2], "!");
+ EXPECT_EQ(result.substr(old_size), "ASCII 72, 67 87!");
+
+ old_size = result.size();
+ StrAppend(&result, ui64s[0], ", ", ui64s[1], "!");
+ EXPECT_EQ(result.substr(old_size), "12345678910, 10987654321!");
+
+ string one = "1"; // Actually, it's the size of this string that we want; a
+ // 64-bit build distinguishes between size_t and uint64,
+ // even though they're both unsigned 64-bit values.
+ old_size = result.size();
+ StrAppend(&result, "And a ", one.size(), " and a ", &result[2] - &result[0],
+ " and a ", one, " 2 3 4", "!");
+ EXPECT_EQ(result.substr(old_size), "And a 1 and a 2 and a 1 2 3 4!");
+
+ // result = StrCat("Single chars won't compile", '!');
+ // result = StrCat("Neither will NULLs", NULL);
+ old_size = result.size();
+ StrAppend(&result, "To output a char by ASCII/numeric value, use +: ",
+ '!' + 0);
+ EXPECT_EQ(result.substr(old_size),
+ "To output a char by ASCII/numeric value, use +: 33");
+
+ float f = 100000.5;
+ old_size = result.size();
+ StrAppend(&result, "A hundred K and a half is ", f);
+ EXPECT_EQ(result.substr(old_size), "A hundred K and a half is 100000.5");
+
+ double d = f;
+ d *= d;
+ old_size = result.size();
+ StrAppend(&result, "A hundred K and a half squared is ", d);
+ EXPECT_EQ(result.substr(old_size),
+ "A hundred K and a half squared is 10000100000.25");
+
+ // Test 9 arguments, the old maximum
+ old_size = result.size();
+ StrAppend(&result, 1, 22, 333, 4444, 55555, 666666, 7777777, 88888888, 9);
+ EXPECT_EQ(result.substr(old_size), "1223334444555556666667777777888888889");
+
+ // No limit thanks to C++11's variadic templates
+ old_size = result.size();
+ StrAppend(&result, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "a", "b", "c", "d", "e",
+ "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r",
+ "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", "C", "D", "E",
+ "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R",
+ "S", "T", "U", "V", "W", "X", "Y", "Z",
+ "No limit thanks to C++11's variadic templates");
+ EXPECT_EQ(result.substr(old_size),
+ "12345678910abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "No limit thanks to C++11's variadic templates");
+}
+
+TEST(StrAppend, Death) {
+ string s = "self";
+ EXPECT_DEBUG_DEATH(StrAppend(&s, s.c_str() + 1), "Check failed:");
+ EXPECT_DEBUG_DEATH(StrAppend(&s, s), "Check failed:");
+}
+
+static void CheckHex64(uint64 v) {
+ using tensorflow::strings::Hex;
+ string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16));
+ string expected = Printf("%016llx", static_cast<unsigned long long>(v));
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ expected = Printf("%08llx", static_cast<unsigned long long>(v));
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v));
+ expected = Printf("%llx", static_cast<unsigned long long>(v));
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+}
+
+static void CheckHex32(uint32 v) {
+ using tensorflow::strings::Hex;
+ string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string expected = Printf("%08x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v));
+ expected = Printf("%x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+}
+
+static void CheckHexSigned32(int32 v) {
+ using tensorflow::strings::Hex;
+ string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string expected = Printf("%08x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+
+ actual = StrCat(Hex(v));
+ expected = Printf("%x", v);
+ EXPECT_EQ(expected, actual) << " decimal value " << v;
+}
+
+static void TestFastPrints() {
+ using tensorflow::strings::Hex;
+
+ // Test min int to make sure that works
+ for (int i = 0; i < 10000; i++) {
+ CheckHex64(i);
+ CheckHex32(i);
+ CheckHexSigned32(i);
+ CheckHexSigned32(-i);
+ }
+ CheckHex64(0x123456789abcdef0ull);
+ CheckHex32(0x12345678);
+
+ int8 minus_one_8bit = -1;
+ EXPECT_EQ("ff", StrCat(Hex(minus_one_8bit)));
+
+ int16 minus_one_16bit = -1;
+ EXPECT_EQ("ffff", StrCat(Hex(minus_one_16bit)));
+}
+
+TEST(Numbers, TestFunctionsMovedOverFromNumbersMain) { TestFastPrints(); }
+
+} // namespace strings
+} // namespace tensorflow
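For reference, a minimal usage sketch of the StrCat/StrAppend/Hex API exercised by the
tests above; the include path is assumed from this change and the sketch itself is not
part of the patch:

```c++
#include <string>

#include "tensorflow/core/lib/strings/strcat.h"  // assumed location of StrCat/StrAppend/Hex

// Builds a message from mixed argument types, as the tests above demonstrate.
std::string BuildMessage(int code, unsigned long long addr) {
  using tensorflow::strings::StrCat;
  using tensorflow::strings::StrAppend;
  using tensorflow::strings::Hex;

  // Integral, floating-point and string-like arguments can be mixed freely.
  std::string msg =
      StrCat("error code ", code, " at 0x",
             Hex(addr, tensorflow::strings::ZERO_PAD_16));
  // StrAppend grows an existing string in place. Appending a string (or a
  // piece of it) to itself is a checked failure, per the Death test above.
  StrAppend(&msg, " (", 100000.5, ")");
  return msg;
}
```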
diff --git a/tensorflow/core/lib/strings/stringprintf.cc b/tensorflow/core/lib/strings/stringprintf.cc
new file mode 100644
index 0000000000..b354706cbd
--- /dev/null
+++ b/tensorflow/core/lib/strings/stringprintf.cc
@@ -0,0 +1,85 @@
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+#include <errno.h>
+#include <stdarg.h> // For va_list and related operations
+#include <stdio.h> // MSVC requires this for _vsnprintf
+#include <vector>
+
+namespace tensorflow {
+namespace strings {
+
+#ifdef COMPILER_MSVC
+enum { IS_COMPILER_MSVC = 1 };
+#else
+enum { IS_COMPILER_MSVC = 0 };
+#endif
+
+void Appendv(string* dst, const char* format, va_list ap) {
+ // First try with a small fixed size buffer
+ static const int kSpaceLength = 1024;
+ char space[kSpaceLength];
+
+ // It's possible for methods that use a va_list to invalidate
+ // the data in it upon use. The fix is to make a copy
+ // of the structure before using it and use that copy instead.
+ va_list backup_ap;
+ va_copy(backup_ap, ap);
+ int result = vsnprintf(space, kSpaceLength, format, backup_ap);
+ va_end(backup_ap);
+
+ if (result < kSpaceLength) {
+ if (result >= 0) {
+ // Normal case -- everything fit.
+ dst->append(space, result);
+ return;
+ }
+
+ if (IS_COMPILER_MSVC) {
+ // Error or MSVC running out of space. MSVC 8.0 and higher
+ // can be asked about space needed with the special idiom below:
+ va_copy(backup_ap, ap);
+ result = vsnprintf(NULL, 0, format, backup_ap);
+ va_end(backup_ap);
+ }
+
+ if (result < 0) {
+ // Just an error.
+ return;
+ }
+ }
+
+ // Increase the buffer size to the size requested by vsnprintf,
+ // plus one for the closing \0.
+ int length = result + 1;
+ char* buf = new char[length];
+
+ // Restore the va_list before we use it again
+ va_copy(backup_ap, ap);
+ result = vsnprintf(buf, length, format, backup_ap);
+ va_end(backup_ap);
+
+ if (result >= 0 && result < length) {
+ // It fit
+ dst->append(buf, result);
+ }
+ delete[] buf;
+}
+
+string Printf(const char* format, ...) {
+ va_list ap;
+ va_start(ap, format);
+ string result;
+ Appendv(&result, format, ap);
+ va_end(ap);
+ return result;
+}
+
+void Appendf(string* dst, const char* format, ...) {
+ va_list ap;
+ va_start(ap, format);
+ Appendv(dst, format, ap);
+ va_end(ap);
+}
+
+} // namespace strings
+} // namespace tensorflow
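The Printf/Appendf entry points above are thin wrappers around Appendv, so custom
printf-style helpers can be layered on top of it in the same way. A small sketch
(the FormatLine helper is hypothetical and not part of the patch):

```c++
#include <cstdarg>
#include <string>

#include "tensorflow/core/lib/strings/stringprintf.h"

// Hypothetical helper that prefixes and terminates a formatted line,
// delegating the formatting itself to Appendv, just like Printf/Appendf do.
std::string FormatLine(const char* format, ...) {
  std::string line = "[log] ";
  va_list ap;
  va_start(ap, format);
  tensorflow::strings::Appendv(&line, format, ap);
  va_end(ap);
  line.push_back('\n');
  return line;
}

// Typical calls:
//   std::string a = tensorflow::strings::Printf("%d %s", 10, "hello");
//   tensorflow::strings::Appendf(&a, " and %d more", 20);
//   std::string b = FormatLine("step %d took %.1f ms", 7, 3.5);
```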
diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h
new file mode 100644
index 0000000000..23ca2583ca
--- /dev/null
+++ b/tensorflow/core/lib/strings/stringprintf.h
@@ -0,0 +1,37 @@
+// Printf variants that place their output in a C++ string.
+//
+// Usage:
+// string result = strings::Printf("%d %s\n", 10, "hello");
+// strings::Appendf(&result, "%d %s\n", 20, "there");
+
+#ifndef TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+#define TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+
+#include <stdarg.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace strings {
+
+// Return a C++ string
+extern string Printf(const char* format, ...)
+ // Tell the compiler to do printf format string checking.
+ TF_PRINTF_ATTRIBUTE(1, 2);
+
+// Append result to a supplied string
+extern void Appendf(string* dst, const char* format, ...)
+ // Tell the compiler to do printf format string checking.
+ TF_PRINTF_ATTRIBUTE(2, 3);
+
+// Lower-level routine that takes a va_list and appends to a specified
+// string. All other routines are just convenience wrappers around it.
+extern void Appendv(string* dst, const char* format, va_list ap);
+
+} // namespace strings
+} // namespace tensorflow
+
+#endif // TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
diff --git a/tensorflow/core/lib/strings/stringprintf_test.cc b/tensorflow/core/lib/strings/stringprintf_test.cc
new file mode 100644
index 0000000000..737ed5c0e0
--- /dev/null
+++ b/tensorflow/core/lib/strings/stringprintf_test.cc
@@ -0,0 +1,113 @@
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+#include <errno.h>   // for errno, ECHILD
+#include <locale.h>  // for setlocale, LC_CTYPE
+#include <string.h>  // for memset, memcpy
+
+#include <string>
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace strings {
+namespace {
+
+TEST(PrintfTest, Empty) {
+ EXPECT_EQ("", Printf("%s", string().c_str()));
+ EXPECT_EQ("", Printf("%s", ""));
+}
+
+TEST(PrintfTest, Misc) {
+// MSVC does not support $ format specifier.
+#if !defined(COMPILER_MSVC)
+ EXPECT_EQ("123hello w", Printf("%3$d%2$s %1$c", 'w', "hello", 123));
+#endif // !COMPILER_MSVC
+}
+
+TEST(AppendfTest, Empty) {
+ string value("Hello");
+ const char* empty = "";
+ Appendf(&value, "%s", empty);
+ EXPECT_EQ("Hello", value);
+}
+
+TEST(AppendfTest, EmptyString) {
+ string value("Hello");
+ Appendf(&value, "%s", "");
+ EXPECT_EQ("Hello", value);
+}
+
+TEST(AppendfTest, String) {
+ string value("Hello");
+ Appendf(&value, " %s", "World");
+ EXPECT_EQ("Hello World", value);
+}
+
+TEST(AppendfTest, Int) {
+ string value("Hello");
+ Appendf(&value, " %d", 123);
+ EXPECT_EQ("Hello 123", value);
+}
+
+TEST(PrintfTest, Multibyte) {
+ // If we are in multibyte mode and feed invalid multibyte sequence,
+ // Printf should return an empty string instead of running
+ // out of memory while trying to determine destination buffer size.
+ // see b/4194543.
+
+ char* old_locale = setlocale(LC_CTYPE, NULL);
+ // Push locale with multibyte mode
+ setlocale(LC_CTYPE, "en_US.utf8");
+
+ const char kInvalidCodePoint[] = "\375\067s";
+ string value = Printf("%.*s", 3, kInvalidCodePoint);
+
+ // In some versions of glibc (e.g. eglibc-2.11.1, aka GRTEv2), snprintf
+ // returns error given an invalid codepoint. Other versions
+ // (e.g. eglibc-2.15, aka pre-GRTEv3) emit the codepoint verbatim.
+ // We test that the output is one of the above.
+ EXPECT_TRUE(value.empty() || value == kInvalidCodePoint);
+
+ // Repeat with longer string, to make sure that the dynamically
+ // allocated path in StringAppendV is handled correctly.
+ int n = 2048;
+ char* buf = new char[n + 1];
+ memset(buf, ' ', n - 3);
+ memcpy(buf + n - 3, kInvalidCodePoint, 4);
+ value = Printf("%.*s", n, buf);
+ // See GRTEv2 vs. GRTEv3 comment above.
+ EXPECT_TRUE(value.empty() || value == buf);
+ delete[] buf;
+
+ setlocale(LC_CTYPE, old_locale);
+}
+
+TEST(PrintfTest, NoMultibyte) {
+ // No multibyte handling, but the string contains funny chars.
+ char* old_locale = setlocale(LC_CTYPE, NULL);
+ setlocale(LC_CTYPE, "POSIX");
+ string value = Printf("%.*s", 3, "\375\067s");
+ setlocale(LC_CTYPE, old_locale);
+ EXPECT_EQ("\375\067s", value);
+}
+
+TEST(PrintfTest, DontOverwriteErrno) {
+ // Check that errno isn't overwritten unless we're printing
+ // something significantly larger than what people are normally
+ // printing in their badly written PLOG() statements.
+ errno = ECHILD;
+ string value = Printf("Hello, %s!", "World");
+ EXPECT_EQ(ECHILD, errno);
+}
+
+TEST(PrintfTest, LargeBuf) {
+ // Check that the large buffer is handled correctly.
+ int n = 2048;
+ char* buf = new char[n + 1];
+ memset(buf, ' ', n);
+ buf[n] = 0;
+ string value = Printf("%s", buf);
+ EXPECT_EQ(buf, value);
+ delete[] buf;
+}
+
+} // namespace
+
+} // namespace strings
+} // namespace tensorflow
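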
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
new file mode 100644
index 0000000000..8c0571b50e
--- /dev/null
+++ b/tensorflow/core/ops/array_ops.cc
@@ -0,0 +1,892 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Pack")
+ .Input("values: N * T")
+ .Output("output: T")
+ .Attr("N: int >= 1")
+ .Attr("T: type")
+ .Doc(R"doc(
+Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
+
+Packs the `N` tensors in `values` into a tensor with rank one higher than each
+tensor in `values` and shape `[N] + values[0].shape`. The output satisfies
+`output[i, ...] = values[i][...]`.
+
+This is the opposite of `unpack`.
+
+values: Must be of same shape and type.
+output: The packed tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Unpack")
+ .Input("value: T")
+ .Output("output: num * T")
+ .Attr("num: int >= 0")
+ .Attr("T: type")
+ .Doc(R"doc(
+Unpacks the outer dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
+
+Unpacks `num` tensors from `value` by chipping it along the first dimension.
+The i'th tensor in `output` is the slice `value[i, ...]`. Each tensor in
+`output` has shape `value.shape[1:]`.
+
+This is the opposite of `pack`.
+
+value: 1-D or higher, with first dimension `num`.
+output: The list of tensors unpacked from `value`.
+)doc");
+
+// --------------------------------------------------------------------------
+// TODO(josh11b): Remove the >= 2 constraint, once we can rewrite the graph
+// in the N == 1 case to remove the node.
+REGISTER_OP("Concat")
+ .Input("concat_dim: int32")
+ .Input("values: N * T")
+ .Output("output: T")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Doc(R"doc(
+Concatenates tensors along one dimension.
+
+concat_dim: 0-D. The dimension along which to concatenate. Must be in the
+ range [0, rank(values)).
+values: The `N` Tensors to concatenate. Their ranks and types must match,
+ and their sizes must match in all dimensions except `concat_dim`.
+output: A `Tensor` with the concatenation of values stacked along the
+ `concat_dim` dimension. This tensor's shape matches that of `values` except
+ in `concat_dim` where it has the sum of the sizes.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Split")
+ .Input("split_dim: int32")
+ .Input("value: T")
+ .Output("output: num_split * T")
+ .Attr("num_split: int >= 1")
+ .Attr("T: type")
+ .Doc(R"doc(
+Splits a tensor into `num_split` tensors along one dimension.
+
+split_dim: 0-D. The dimension along which to split. Must be in the range
+ `[0, rank(value))`.
+num_split: The number of ways to split. Must evenly divide
+ `value.shape[split_dim]`.
+value: The tensor to split.
+output: They are identically shaped tensors, whose shape matches that of `value`
+ except along `split_dim`, where their sizes are
+  `value.shape[split_dim] / num_split`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Const")
+ .Output("output: dtype")
+ .Attr("value: tensor")
+ .Attr("dtype: type")
+ .Doc(R"doc(
+Returns a constant tensor.
+
+value: Attr `value` is the tensor to return.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ZerosLike")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns a tensor of zeros with the same shape and type as x.
+
+x: a tensor of type T.
+y: a tensor of the same shape and type as x but filled with zeros.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Diag")
+ .Input("diagonal: T")
+ .Output("output: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Returns a diagonal tensor with given diagonal values.
+
+Given a `diagonal`, this operation returns a tensor with the `diagonal` and
+everything else padded with zeros. The diagonal is computed as follows:
+
+Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of
+rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:
+
+`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.
+
+For example:
+
+```prettyprint
+# 'diagonal' is [1, 2, 3, 4]
+tf.diag(diagonal) ==> [[1, 0, 0, 0]
+ [0, 2, 0, 0]
+ [0, 0, 3, 0]
+ [0, 0, 0, 4]]
+```
+
+diagonal: Rank k tensor where k is at most 3.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Reverse")
+ .Input("tensor: T")
+ .Input("dims: bool")
+ .Output("output: T")
+ .Attr("T: {uint8, int8, int32, bool, float, double}")
+ .Doc(R"Doc(
+Reverses specific dimensions of a tensor.
+
+Given a `tensor`, and a `bool` tensor `dims` representing the dimensions
+of `tensor`, this operation reverses each dimension i of `tensor` where
+`dims[i]` is `True`.
+
+`tensor` can have up to 8 dimensions. The number of dimensions
+of `tensor` must equal the number of elements in `dims`. In other words:
+
+`rank(tensor) = size(dims)`
+
+For example:
+
+```prettyprint
+# tensor 't' is [[[[ 0, 1, 2, 3],
+# [ 4, 5, 6, 7],
+# [ 8, 9, 10, 11]],
+# [[12, 13, 14, 15],
+# [16, 17, 18, 19],
+# [20, 21, 22, 23]]]]
+# tensor 't' shape is [1, 2, 3, 4]
+
+# 'dims' is [False, False, False, True]
+reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
+ [ 7, 6, 5, 4],
+ [ 11, 10, 9, 8]],
+ [[15, 14, 13, 12],
+ [19, 18, 17, 16],
+ [23, 22, 21, 20]]]]
+
+# 'dims' is [False, True, False, False]
+reverse(t, dims) ==> [[[[12, 13, 14, 15],
+ [16, 17, 18, 19],
+                        [20, 21, 22, 23]],
+ [[ 0, 1, 2, 3],
+ [ 4, 5, 6, 7],
+ [ 8, 9, 10, 11]]]]
+
+# 'dims' is [False, False, True, False]
+reverse(t, dims) ==> [[[[8, 9, 10, 11],
+ [4, 5, 6, 7],
+                        [0, 1, 2, 3]],
+ [[20, 21, 22, 23],
+ [16, 17, 18, 19],
+ [12, 13, 14, 15]]]]
+```
+
+tensor: Up to 8-D.
+dims: 1-D. The dimensions to reverse.
+output: The same shape as `tensor`.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("EditDistance")
+ .Input("hypothesis_indices: int64")
+ .Input("hypothesis_values: T")
+ .Input("hypothesis_shape: int64")
+ .Input("truth_indices: int64")
+ .Input("truth_values: T")
+ .Input("truth_shape: int64")
+ .Attr("normalize: bool = True")
+ .Attr("T: type")
+ .Output("output: float")
+ .Doc(R"doc(
+Computes the (possibly normalized) Levenshtein Edit Distance.
+
+The inputs are variable-length sequences provided by SparseTensors
+ (hypothesis_indices, hypothesis_values, hypothesis_shape)
+and
+ (truth_indices, truth_values, truth_shape).
+
+The inputs are:
+
+hypothesis_indices: The indices of the hypothesis list SparseTensor.
+ This is an N x R int64 matrix.
+hypothesis_values: The values of the hypothesis list SparseTensor.
+ This is an N-length vector.
+hypothesis_shape: The shape of the hypothesis list SparseTensor.
+ This is an R-length vector.
+truth_indices: The indices of the truth list SparseTensor.
+ This is an M x R int64 matrix.
+truth_values: The values of the truth list SparseTensor.
+ This is an M-length vector.
+truth_shape: The shape of the truth list SparseTensor.
+ This is an R-length vector.
+normalize: boolean (if true, edit distances are normalized by length of truth).
+
+The output is:
+
+output: A dense float tensor with rank R - 1.
+
+For the example input:
+
+ // hypothesis represents a 2x1 matrix with variable-length values:
+ // (0,0) = ["a"]
+ // (1,0) = ["b"]
+ hypothesis_indices = [[0, 0, 0],
+ [1, 0, 0]]
+ hypothesis_values = ["a", "b"]
+ hypothesis_shape = [2, 1, 1]
+
+ // truth represents a 2x2 matrix with variable-length values:
+ // (0,0) = []
+ // (0,1) = ["a"]
+ // (1,0) = ["b", "c"]
+ // (1,1) = ["a"]
+ truth_indices = [[0, 1, 0],
+ [1, 0, 0],
+ [1, 0, 1],
+ [1, 1, 0]]
+ truth_values = ["a", "b", "c", "a"]
+ truth_shape = [2, 2, 2]
+ normalize = true
+
+The output will be:
+
+ // output is a 2x2 matrix with edit distances normalized by truth lengths.
+ output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
+ [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Fill")
+ .Input("dims: int32")
+ .Input("value: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Creates a tensor filled with a scalar value.
+
+This operation creates a tensor of shape `dims` and fills it with `value`.
+
+For example:
+
+```prettyprint
+# output tensor shape needs to be [2, 3]
+# so 'dims' is [2, 3]
+fill(dims, 9) ==> [[9, 9, 9]
+ [9, 9, 9]]
+```
+
+dims: 1-D. Represents the shape of the output tensor.
+value: 0-D (scalar). Value to fill the returned tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Gather")
+ .Input("params: Tparams")
+ .Input("indices: Tindices")
+ .Output("output: Tparams")
+ .Attr("Tparams: type")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Gather slices from `params` according to `indices`.
+
+`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
+Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
+
+ # Scalar indices
+ output[:, ..., :] = params[indices, :, ... :]
+
+ # Vector indices
+ output[i, :, ..., :] = params[indices[i], :, ... :]
+
+ # Higher rank indices
+ output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
+
+If `indices` is a permutation and `len(indices) == params.shape[0]` then
+this operation will permute `params` accordingly.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/Gather.png" alt>
+</div>
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Identity")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Return a tensor with the same shape and contents as the input tensor or value.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefIdentity")
+ .Input("input: Ref(T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Return the same ref tensor as the input ref tensor.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("StopGradient")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Stops gradient computation.
+
+When executed in a graph, this op outputs its input tensor as-is.
+
+When building ops to compute gradients, this op prevents the contribution of
+its inputs to be taken into account. Normally, the gradient generator adds ops
+to a graph to compute the derivatives of a specified 'loss' by recursively
+finding out inputs that contributed to its computation. If you insert this op
+in the graph, its inputs are masked from the gradient generator. They are not
+taken into account for computing gradients.
+
+This is useful any time you want to compute a value with TensorFlow but need
+to pretend that the value was a constant. Some examples include:
+
+* The *EM* algorithm where the *M-step* should not involve backpropagation
+ through the output of the *E-step*.
+* Contrastive divergence training of Boltzmann machines where, when
+ differentiating the energy function, the training must not backpropagate
+ through the graph that generated the samples from the model.
+* Adversarial training, where no backprop should happen through the adversarial
+ example generation process.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("CheckNumerics")
+ .Input("tensor: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("message: string")
+ .Doc(R"doc(
+Checks a tensor for NaN and Inf values.
+
+When run, reports an `InvalidArgument` error if `tensor` has any values
+that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
+
+message: Prefix of the error message.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Reshape")
+ .Input("tensor: T")
+ .Input("shape: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"Doc(
+Reshapes a tensor.
+
+Given `tensor`, this operation returns a tensor that has the same values
+as `tensor` with shape `shape`.
+
+If `shape` is the special value `[-1]`, then `tensor` is flattened and the
+operation outputs a 1-D tensor with all elements of `tensor`.
+
+If `shape` is 1-D or higher, then the operation returns a tensor with shape
+`shape` filled with the values of `tensor`. In this case, the number of elements
+implied by `shape` must be the same as the number of elements in `tensor`.
+
+For example:
+
+```prettyprint
+# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
+# tensor 't' has shape [9]
+reshape(t, [3, 3]) ==> [[1, 2, 3]
+ [4, 5, 6]
+ [7, 8, 9]]
+
+# tensor 't' is [[[1, 1], [2, 2]],
+#                [[3, 3], [4, 4]]]
+# tensor 't' has shape [2, 2, 2]
+reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
+ [3, 3, 4, 4]]
+
+# tensor 't' is [[[1, 1, 1],
+# [2, 2, 2]],
+# [[3, 3, 3],
+# [4, 4, 4]],
+# [[5, 5, 5],
+# [6, 6, 6]]]
+# tensor 't' has shape [3, 2, 3]
+# pass '[-1]' to flatten 't'
+reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
+```
+
+shape: Defines the shape of the output tensor.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("InvertPermutation")
+ .Input("x: int32")
+ .Output("y: int32")
+ .Doc(R"doc(
+Computes the inverse permutation of a tensor.
+
+This operation computes the inverse of an index permutation. It takes a 1-D
+integer tensor `x`, which represents the indices of a zero-based array, and
+swaps each value with its index position. In other words, for an output tensor
+`y` and an input tensor `x`, this operation computes the following:
+
+`y[x[i]] = i for i in [0, 1, ..., len(x) - 1]`
+
+The values must include 0. There can be no duplicate values or negative values.
+
+For example:
+
+```prettyprint
+# tensor `x` is [3, 4, 0, 2, 1]
+invert_permutation(x) ==> [2, 4, 3, 0, 1]
+```
+
+x: 1-D.
+y: 1-D.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Transpose")
+ .Input("x: T")
+ .Input("perm: int32")
+ .Output("y: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Shuffle dimensions of x according to a permutation.
+
+The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+ `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Unique")
+ .Input("x: T")
+ .Output("y: T")
+ .Output("idx: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Finds unique elements in a 1-D tensor.
+
+This operation returns a tensor `y` containing all of the unique elements of `x`
+sorted in the same order that they occur in `x`. This operation also returns a
+tensor `idx` the same size as `x` that contains the index of each value of `x`
+in the unique output `y`. In other words:
+
+`y[idx[i]] = x[i] for i in [0, 1, ..., len(x) - 1]`
+
+For example:
+
+```prettyprint
+# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
+y, idx = unique(x)
+y ==> [1, 2, 4, 7, 8]
+idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
+```
+
+x: 1-D.
+y: 1-D.
+idx: 1-D.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Shape")
+ .Input("input: T")
+ .Output("output: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the shape of a tensor.
+
+This operation returns a 1-D integer tensor representing the shape of `input`.
+
+For example:
+
+```prettyprint
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+shape(t) ==> [2, 2, 3]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ReverseSequence")
+ .Input("input: T")
+ .Input("seq_lengths: int64")
+ .Output("output: T")
+ .Attr("seq_dim: int")
+ .Attr("T: type")
+ .Doc(R"doc(
+Reverses variable length slices in dimension `seq_dim`.
+
+This op first slices `input` along the first dimension, and for each slice `i`,
+reverses the first `seq_lengths[i]` elements along the dimension `seq_dim`.
+
+The elements of `seq_lengths` must obey `seq_lengths[i] < input.dims[seq_dim]`,
+and `seq_lengths` must be a vector of length `input.dims(0)`.
+
+The output slice `i` along dimension 0 is then given by input slice `i`, with
+the first `seq_lengths[i]` slices along dimension `seq_dim` reversed.
+
+For example:
+
+```prettyprint
+# Given this:
+seq_dim = 1
+input.dims = (4, ...)
+seq_lengths = [7, 2, 3, 5]
+
+# then slices of input are reversed on seq_dim, but only up to seq_lengths:
+output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
+output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
+output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
+output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]
+
+# while entries past seq_lens are copied through:
+output[0, 7:, :, ...] = input[0, 7:, :, ...]
+output[1, 2:, :, ...] = input[1, 2:, :, ...]
+output[2, 3:, :, ...] = input[2, 3:, :, ...]
+output[3, 5:, :, ...] = input[3, 5:, :, ...]
+```
+
+input: The input to reverse.
+seq_lengths: 1-D with length `input.dims(0)` and
+ `max(seq_lengths) < input.dims(seq_dim)`
+seq_dim: The dimension which is partially reversed.
+output: The partially reversed input. It has the same shape as `input`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Rank")
+ .Input("input: T")
+ .Output("output: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the rank of a tensor.
+
+This operation returns an integer representing the rank of `input`.
+
+For example:
+
+```prettyprint
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+# shape of tensor 't' is [2, 2, 3]
+rank(t) ==> 3
+```
+
+**Note**: The rank of a tensor is not the same as the rank of a matrix. The rank
+of a tensor is the number of indices required to uniquely select each element
+of the tensor. Rank is also known as "order", "degree", or "ndims."
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Size")
+ .Input("input: T")
+ .Output("output: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the size of a tensor.
+
+This operation returns an integer representing the number of elements in
+`input`.
+
+For example:
+
+```prettyprint
+# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+size(t) ==> 12
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Slice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .Doc(R"doc(
+Return a slice from 'input'.
+
+The output tensor is a tensor with dimensions described by 'size'
+whose values are extracted from 'input' starting at the offsets in
+'begin'.
+
+*Requirements*:
+ 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)
+
+begin: begin[i] specifies the offset into the 'i'th dimension of
+ 'input' to slice from.
+size: size[i] specifies the number of elements of the 'i'th dimension
+ of 'input' to slice. If size[i] is -1, all remaining elements in dimension
+ i are included in the slice (i.e. this is equivalent to setting
+ size[i] = input.dim_size(i) - begin[i]).
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Tile")
+ .Input("input: T")
+ .Input("multiples: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Constructs a tensor by tiling a given tensor.
+
+This operation creates a new tensor by replicating `input` `multiples` times.
+The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements,
+and the values of `input` are replicated `multiples[i]` times along the 'i'th
+dimension. For example, tiling `[a b c d]` by `[2]` produces
+`[a b c d a b c d]`.
+
+input: 1-D or higher.
+multiples: 1-D. Length must be the same as the number of dimensions in `input`
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("TileGrad")
+ .Input("input: T")
+ .Input("multiples: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Returns the gradient of `Tile`.
+
+Since `Tile` takes an input and repeats the input `multiples` times
+along each dimension, `TileGrad` takes in `multiples` and aggregates
+each repeated tile of `input` into `output`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Where")
+ .Input("input: bool")
+ .Output("index: int64")
+ .Doc(R"doc(
+Returns locations of true values in a boolean tensor.
+
+This operation returns the coordinates of true elements in `input`. The
+coordinates are returned in a 2-D tensor where the first dimension (rows)
+represents the number of true elements, and the second dimension (columns)
+represents the coordinates of the true elements. Keep in mind, the shape of
+the output tensor can vary depending on how many true values there are in
+`input`. Indices are output in row-major order.
+
+For example:
+
+```prettyprint
+# 'input' tensor is [[True, False]
+# [True, False]]
+# 'input' has two true values, so output has two coordinates.
+# 'input' has rank of 2, so coordinates have two indices.
+where(input) ==> [[0, 0],
+ [1, 0]]
+
+# `input` tensor is [[[True, False]
+# [True, False]]
+# [[False, True]
+# [False, True]]
+# [[False, False]
+# [False, True]]]
+# 'input' has 5 true values, so output has 5 coordinates.
+# 'input' has rank of 3, so coordinates have three indices.
+where(input) ==> [[0, 0, 0],
+ [0, 1, 0],
+ [1, 0, 1],
+ [1, 1, 1],
+ [2, 1, 1]]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("BroadcastGradientArgs")
+ .Input("s0: int32")
+ .Input("s1: int32")
+ .Output("r0: int32")
+ .Output("r1: int32")
+ .Doc(R"doc(
+Return the reduction indices for computing gradients of s0 op s1 with broadcast.
+
+This is typically used by gradient computations for a broadcasting operation.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Pad")
+ .Input("input: T")
+ .Input("paddings: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Pads a tensor with zeros.
+
+This operation pads an `input` with zeros according to the `paddings` you
+specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is the
+rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+how many zeros to add before the contents of `input` in that dimension, and
+`paddings[D, 1]` indicates how many zeros to add after the contents of `input`
+in that dimension.
+
+The padded size of each dimension D of the output is:
+
+`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
+
+For example:
+
+```prettyprint
+# 't' is [[1, 1], [2, 2]]
+# 'paddings' is [[1, 1], [2, 2]]
+# rank of 't' is 2
+pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
+                      [0, 0, 1, 1, 0, 0]
+                      [0, 0, 2, 2, 0, 0]
+                      [0, 0, 0, 0, 0, 0]]
+```
+
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Placeholder")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .Doc(R"doc(
+A placeholder op for a value that will be fed into the computation.
+
+N.B. This operation will fail with an error if it is executed. It is
+intended as a way to represent a value that will always be fed, and to
+provide attrs that enable the fed value to be checked at runtime.
+
+output: A placeholder tensor that must be replaced using the feed mechanism.
+dtype: The type of elements in the tensor.
+shape: (Optional) The shape of the tensor. If the shape has 0 dimensions, the
+ shape is unconstrained.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ExpandDims")
+ .Input("input: T")
+ .Input("dim: int32")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Inserts a dimension of 1 into a tensor's shape.
+
+Given a tensor `input`, this operation inserts a dimension of 1 at the
+dimension index `dim` of `input`'s shape. The dimension index `dim` starts at
+zero; if you specify a negative number for `dim` it is counted backward from
+the end.
+
+This operation is useful if you want to add a batch dimension to a single
+element. For example, if you have a single image of shape `[height, width,
+channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
+which will make the shape `[1, height, width, channels]`.
+
+Other examples:
+
+```prettyprint
+# 't' is a tensor of shape [2]
+shape(expand_dims(t, 0)) ==> [1, 2]
+shape(expand_dims(t, 1)) ==> [2, 1]
+shape(expand_dims(t, -1)) ==> [2, 1]
+
+# 't2' is a tensor of shape [2, 3, 5]
+shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
+shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
+shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
+```
+
+This operation requires that:
+
+`-1-input.dims() <= dim <= input.dims()`
+
+This operation is related to `squeeze()`, which removes dimensions of
+size 1.
+
+dim: 0-D (scalar). Specifies the dimension index at which to
+ expand the shape of `input`.
+output: Contains the same data as `input`, but its shape has an additional
+ dimension of size 1 added.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Squeeze")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("squeeze_dims: list(int) >= 0 = []")
+ .Doc(R"doc(
+Removes dimensions of size 1 from the shape of a tensor.
+
+Given a tensor `input`, this operation returns a tensor of the same type with
+all dimensions of size 1 removed. If you don't want to remove all size 1
+dimensions, you can remove specific size 1 dimensions by specifying
+`squeeze_dims`.
+
+For example:
+
+```prettyprint
+# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
+shape(squeeze(t)) ==> [2, 3]
+```
+
+Or, to remove specific size 1 dimensions:
+
+```prettyprint
+# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
+shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
+```
+
+input: The `input` to squeeze.
+squeeze_dims: If specified, only squeezes the dimensions listed. The dimension
+ index starts at 0. It is an error to squeeze a dimension that is not 1.
+output: Contains the same data as `input`, but has one or more dimensions of
+ size 1 removed.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ListDiff")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("out: T")
+ .Output("idx: int32")
+ .Attr("T: type")
+ .Doc(R"doc(
+Computes the difference between two lists of numbers.
+
+Given a list `x` and a list `y`, this operation returns a list `out` that
+represents all numbers that are in `x` but not in `y`. The returned list `out`
+is sorted in the same order that the numbers appear in `x` (duplicates are
+preserved). This operation also returns a list `idx` that represents the
+position of each `out` element in `x`. In other words:
+
+`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]`
+
+For example, given this input:
+
+```prettyprint
+x = [1, 2, 3, 4, 5, 6]
+y = [1, 3, 5]
+```
+
+This operation would return:
+
+```prettyprint
+out ==> [2, 4, 6]
+idx ==> [1, 3, 5]
+```
+
+x: 1-D. Values to keep.
+y: 1-D. Values to remove.
+out: 1-D. Values present in `x` but not in `y`.
+idx: 1-D. Positions of `x` values preserved in `out`.
+)doc");
+
+} // namespace tensorflow
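Two of the doc strings above state their results as formulas; the following plain C++
sketch (independent of the op/kernel machinery and not part of the patch) reproduces
them so the examples can be checked by hand:

```c++
#include <cassert>
#include <utility>
#include <vector>

// InvertPermutation: y[x[i]] = i for i in [0, 1, ..., len(x) - 1].
std::vector<int> InvertPermutationExample(const std::vector<int>& x) {
  std::vector<int> y(x.size(), -1);
  for (int i = 0; i < static_cast<int>(x.size()); ++i) {
    // The doc requires values in [0, len(x)) with no duplicates.
    assert(x[i] >= 0 && x[i] < static_cast<int>(x.size()) && y[x[i]] == -1);
    y[x[i]] = i;
  }
  return y;  // e.g. {3, 4, 0, 2, 1} -> {2, 4, 3, 0, 1}
}

// Pad: output dim D has size paddings(D, 0) + input.dim_size(D) + paddings(D, 1).
std::vector<int> PaddedShapeExample(
    const std::vector<int>& dims,
    const std::vector<std::pair<int, int>>& paddings) {
  std::vector<int> out;
  for (int d = 0; d < static_cast<int>(dims.size()); ++d) {
    out.push_back(paddings[d].first + dims[d] + paddings[d].second);
  }
  return out;  // e.g. dims {2, 2}, paddings {{1, 1}, {2, 2}} -> {4, 6}
}
```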
diff --git a/tensorflow/core/ops/attention_ops.cc b/tensorflow/core/ops/attention_ops.cc
new file mode 100644
index 0000000000..6fa9a6e821
--- /dev/null
+++ b/tensorflow/core/ops/attention_ops.cc
@@ -0,0 +1,54 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// Tout = extract_glimpse(Tin, size, offsets) extract the glimpse of size size
+// centered at location offsets from the input tensor Tin
+//
+// REQUIRES: Tin.dims() == 4
+//
+REGISTER_OP("ExtractGlimpse")
+ .Input("input: float")
+ .Input("size: int32")
+ .Input("offsets: float")
+ .Output("glimpse: float")
+ .Attr("centered: bool = true")
+ .Attr("normalized: bool = true")
+ .Attr("uniform_noise: bool = true")
+ .Doc(R"doc(
+Extracts a glimpse from the input tensor.
+
+Returns a set of windows called glimpses extracted at location `offsets`
+from the input tensor. If the windows only partially overlap the inputs, the
+non-overlapping areas will be filled with random noise.
+
+The result is a 4-D tensor of shape `[batch_size, glimpse_height,
+glimpse_width, channels]`. The channels and batch dimensions are the same as
+those of the input tensor. The height and width of the output windows are
+specified in the `size` parameter.
+
+The arguments `normalized` and `centered` control how the windows are built:
+* If the coordinates are normalized but not centered, 0.0 and 1.0
+ correspond to the minimum and maximum of each height and width dimension.
+* If the coordinates are both normalized and centered, they range from -1.0 to
+ 1.0. The coordinates (-1.0, -1.0) correspond to the upper left corner, the
+ lower right corner is located at (1.0, 1.0) and the center is at (0, 0).
+* If the coordinates are not normalized they are interpreted as numbers of pixels.
+
+input: A 4-D float tensor of shape `[batch_size, height, width, channels]`.
+size: A 1-D tensor of 2 elements containing the size of the glimpses to extract.
+  The glimpse height must be specified first, followed by the glimpse width.
+offsets: A 2-D integer tensor of shape `[batch_size, 2]` containing the x, y
+ locations of the center of each window.
+glimpse: A tensor representing the glimpses `[batch_size, glimpse_height,
+ glimpse_width, channels]`.
+centered: indicates if the offset coordinates are centered relative to
+ the image, in which case the (0, 0) offset is relative to the center of the
+ input images. If false, the (0,0) offset corresponds to the upper left corner
+ of the input images.
+normalized: indicates if the offset coordinates are normalized.
+uniform_noise: indicates if the noise should be generated using a
+ uniform distribution or a gaussian distribution.
+)doc");
+
+} // namespace tensorflow
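The `centered`/`normalized` coordinate convention described in the ExtractGlimpse doc
can be summarized per dimension as below; this is an illustration of the documented
mapping only, not the kernel implementation, and is not part of the patch:

```c++
// Maps one offset coordinate to a pixel position along a dimension of length
// `extent`, following the convention stated in the doc above.
float GlimpseCenterExample(float offset, int extent,
                           bool centered, bool normalized) {
  if (normalized && centered) {
    // Offsets range over [-1, 1]; (-1, -1) is the upper left corner,
    // (1, 1) the lower right, and (0, 0) the center.
    return (offset + 1.0f) / 2.0f * extent;
  }
  if (normalized) {
    // Offsets range over [0, 1] and span the dimension.
    return offset * extent;
  }
  // Not normalized: offsets are already pixel positions.
  return offset;
}
```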
diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc
new file mode 100644
index 0000000000..a98b0295ee
--- /dev/null
+++ b/tensorflow/core/ops/candidate_sampling_ops.cc
@@ -0,0 +1,351 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("UniformCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a uniform distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, for each sampled
+  candidate representing the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("LogUniformCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a log-uniform distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, for each sampled
+  candidate representing the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("LearnedUnigramCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, for each sampled
+  candidate representing the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, for each sampled
+  candidate representing the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+seed: If either seed or seed2 is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("FixedUnigramCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("range_max: int >= 1")
+ .Attr("vocab_file: string = ''")
+ .Attr("distortion: float = 1.0")
+ .Attr("num_reserved_ids: int = 0")
+ .Attr("num_shards: int >= 1 = 1")
+ .Attr("shard: int >= 0 = 0")
+ .Attr("unigrams: list(float) = []")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+A unigram sampler could use a fixed unigram distribution read from a
+file or passed in as an in-memory array instead of building up the distribution
+from data on the fly. There is also an option to skew the distribution by
+applying a distortion power to the weights.
+
+The vocabulary file should be in CSV-like format, with the last field
+being the weight associated with the word.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, for each sampled
+  candidate representing the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to randomly sample per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+range_max: The sampler will sample integers from the interval [0, range_max).
+vocab_file: Each valid line in this file (which should have a CSV-like format)
+ corresponds to a valid word ID. IDs are in sequential order, starting from
+ num_reserved_ids. The last entry in each line is expected to be a value
+ corresponding to the count or relative probability. Exactly one of vocab_file
+ and unigrams needs to be passed to this op.
+distortion: The distortion is used to skew the unigram probability distribution.
+ Each weight is first raised to the distortion's power before adding to the
+ internal unigram distribution. As a result, distortion = 1.0 gives regular
+ unigram sampling (as defined by the vocab file), and distortion = 0.0 gives
+ a uniform distribution.
+num_reserved_ids: Optionally some reserved IDs can be added in the range [0,
+ ..., num_reserved_ids) by the users. One use case is that a special unknown
+ word token is used as ID 0. These IDs will have a sampling probability of 0.
+num_shards: A sampler can be used to sample from a subset of the original range
+ in order to speed up the whole computation through parallelism. This parameter
+ (together with 'shard') indicates the number of partitions that are being
+ used in the overall computation.
+shard: A sampler can be used to sample from a subset of the original range
+ in order to speed up the whole computation through parallelism. This parameter
+ (together with 'num_shards') indicates the particular partition number of a
+ sampler op, when partitioning is being used.
+unigrams: A list of unigram counts or probabilities, one per ID in sequential
+ order. Exactly one of vocab_file and unigrams should be passed to this op.
+seed: If either seed or seed2 is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("AllCandidateSampler")
+ .Input("true_classes: int64")
+ .Output("sampled_candidates: int64")
+ .Output("true_expected_count: float")
+ .Output("sampled_expected_count: float")
+ .Attr("num_true: int >= 1")
+ .Attr("num_sampled: int >= 1")
+ .Attr("unique: bool")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Generates labels for candidate sampling with a learned unigram distribution.
+
+See explanations of candidate sampling and the data formats at
+go/candidate-sampling.
+
+For each batch, this op picks a single set of sampled candidate labels.
+
+The advantages of sampling candidates per-batch are simplicity and the
+possibility of efficient dense matrix multiplication. The disadvantage is that
+the sampled candidates must be chosen independently of the context and of the
+true labels.
+
+true_classes: A batch_size * num_true matrix, in which each row contains the
+ IDs of the num_true target_classes in the corresponding original label.
+sampled_candidates: A vector of length num_sampled, in which each element is
+ the ID of a sampled candidate.
+true_expected_count: A batch_size * num_true matrix, representing
+ the number of times each candidate is expected to occur in a batch
+ of sampled candidates. If unique=true, then this is a probability.
+sampled_expected_count: A vector of length num_sampled, for each sampled
+  candidate representing the number of times the candidate is expected
+ to occur in a batch of sampled candidates. If unique=true, then this is a
+ probability.
+num_true: Number of true labels per context.
+num_sampled: Number of candidates to produce per batch.
+unique: If unique is true, we sample with rejection, so that all sampled
+ candidates in a batch are unique. This requires some approximation to
+ estimate the post-rejection sampling probabilities.
+seed: If either seed or seed2 is set to be non-zero, the random number
+  generator is seeded by the given seed. Otherwise, it is seeded by a
+  random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+REGISTER_OP("ComputeAccidentalHits")
+ .Input("true_classes: int64")
+ .Input("sampled_candidates: int64")
+ .Output("indices: int32")
+ .Output("ids: int64")
+ .Output("weights: float")
+ .Attr("num_true: int")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Doc(R"doc(
+Computes the ids of the positions in sampled_candidates that match true_labels.
+
+When doing log-odds NCE, the result of this op should be passed through a
+SparseToDense op, then added to the logits of the sampled candidates. This has
+the effect of 'removing' the sampled labels that match the true labels by
+making the classifier sure that they are sampled labels.
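+
+For example (schematic pseudo-code, not tied to a particular client API):
+
+    indices, ids, weights = ComputeAccidentalHits(true_classes, sampled_candidates)
+    # Scatter `weights` (-FLOAT_MAX) into a dense [batch_size, num_sampled] tensor
+    # with SparseToDense, using (indices, ids) as coordinates, and add the result
+    # to the sampled logits so that accidental hits effectively score -infinity.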
+
+true_classes: The true_classes output of UnpackSparseLabels.
+sampled_candidates: The sampled_candidates output of CandidateSampler.
+indices: A vector of indices corresponding to rows of true_classes.
+ids: A vector of IDs of positions in sampled_candidates that match a true_label
+ for the row with the corresponding index in indices.
+weights: A vector of the same length as indices and ids, in which each element
+ is -FLOAT_MAX.
+num_true: Number of true labels per context.
+seed: If either seed or seed2 are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc
new file mode 100644
index 0000000000..517b2d2742
--- /dev/null
+++ b/tensorflow/core/ops/control_flow_ops.cc
@@ -0,0 +1,179 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Switch")
+ .Input("data: T")
+ .Input("pred: bool")
+ .Output("output_false: T")
+ .Output("output_true: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Forwards `data` to the output port determined by `pred`.
+
+If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+the data goes to `output_false`.
+
+See also `RefSwitch` and `Merge`.
+
+data: The tensor to be forwarded to the appropriate output.
+pred: A scalar that specifies which output port will receive data.
+output_false: If `pred` is false, data will be forwarded to this output.
+output_true: If `pred` is true, data will be forwarded to this output.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefSwitch")
+ .Input("data: Ref(T)")
+ .Input("pred: bool")
+ .Output("output_false: Ref(T)")
+ .Output("output_true: Ref(T)")
+ .Attr("T: type")
+ .Doc(R"doc(
+Forwards the ref tensor `data` to the output port determined by `pred`.
+
+If `pred` is true, the `data` input is forwarded to `output_true`. Otherwise,
+the data goes to `output_false`.
+
+See also `Switch` and `Merge`.
+
+data: The ref tensor to be forwarded to the appropriate output.
+pred: A scalar that specifies which output port will receive data.
+output_false: If `pred` is false, data will be forwarded to this output.
+output_true: If `pred` is true, data will be forwarded to this output.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefSelect")
+ .Input("index: int32")
+ .Input("inputs: Ref(N * T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Forwards the `index`th element of `inputs` to `output`.
+
+index: A scalar that determines the input that gets selected.
+inputs: A list of ref tensors, one of which will be forwarded to `output`.
+output: The forwarded tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Merge")
+ .Input("inputs: N * T")
+ .Output("output: T")
+ .Output("value_index: int32")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Forwards the value of an available tensor from `inputs` to `output`.
+
+`Merge` waits for at least one of the tensors in `inputs` to become available.
+It is usually combined with `Switch` to implement branching.
+
+`Merge` forwards the first tensor to become available to `output`, and sets
+`value_index` to its index in `inputs`.
+
+It is an error if more than one tensor in `inputs` is available.
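+
+For example, a simple conditional can be sketched as follows (pseudo-code;
+`then_fn` and `else_fn` stand for arbitrary per-branch subgraphs):
+
+    output_false, output_true = Switch(data, pred)
+    result, value_index = Merge([else_fn(output_false), then_fn(output_true)])
+    # Only the branch that received `data` produces a value; Merge forwards it.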
+
+inputs: The input tensors, exactly one of which will become available.
+output: Will be set to the available input tensor.
+value_index: The index of the chosen input tensor in `inputs`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Enter")
+ .Input("data: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("frame_name: string")
+ .Attr("is_constant: bool = false")
+ .Attr("parallel_iterations: int = 10")
+ .Doc(R"doc(
+Creates or finds a child frame, and makes `data` available to the child frame.
+
+This op is used together with `Exit` to create loops in the graph.
+The unique `frame_name` is used by the `Executor` to identify frames. If
+`is_constant` is true, `output` is a constant in the child frame; otherwise
+it may be changed in the child frame. At most `parallel_iterations` iterations
+are run in parallel in the child frame.
+
+data: The tensor to be made available to the child frame.
+frame_name: The name of the child frame.
+is_constant: If true, the output is constant within the child frame.
+parallel_iterations: The number of iterations allowed to run in parallel.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RefEnter")
+ .Input("data: Ref(T)")
+ .Output("output: Ref(T)")
+ .Attr("T: type")
+ .Attr("frame_name: string")
+ .Attr("is_constant: bool = false")
+ .Attr("parallel_iterations: int = 10")
+ .Doc(R"doc(
+Creates or finds a child frame, and makes `data` available to the child frame.
+
+The unique `frame_name` is used by the `Executor` to identify frames. If
+`is_constant` is true, `output` is a constant in the child frame; otherwise
+it may be changed in the child frame. At most `parallel_iterations` iterations
+are run in parallel in the child frame.
+
+data: The tensor to be made available to the child frame.
+frame_name: The name of the child frame.
+is_constant: If true, the output is constant within the child frame.
+parallel_iterations: The number of iterations allowed to run in parallel.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("Exit")
+ .Input("data: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Exits the current frame to its parent frame.
+
+Exit makes its input `data` available to the parent frame.
+
+data: The tensor to be made available to the parent frame.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("NextIteration")
+ .Input("data: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Makes its input available to the next iteration.
+
+data: The tensor to be made available to the next iteration.
+output: The same tensor as `data`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("LoopCond")
+ .Input("input: bool")
+ .Output("output: bool")
+ .Doc(R"doc(
+Forwards the input to the output.
+
+This operator represents the loop termination condition used by the
+"pivot" switches of a loop.
+
+input: A boolean scalar, representing the branch predicate of the Switch op.
+output: The same tensor as `input`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ControlTrigger")
+ .Doc(R"doc(
+Does nothing. Serves as a control trigger for scheduling. Only useful as a
+placeholder for control edges.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
new file mode 100644
index 0000000000..49eba33188
--- /dev/null
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -0,0 +1,357 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("DynamicPartition")
+ .Input("data: T")
+ .Input("partitions: int32")
+ .Output("outputs: num_partitions * T")
+ .Attr("num_partitions: int")
+ .Attr("T: type")
+ .Doc(R"doc(
+Partitions `data` into `num_partitions` tensors using indices from `partitions`.
+
+For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`
+becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i`
+are placed in `outputs[i]` in lexicographic order of `js`, and the first
+dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`.
+In detail,
+
+ outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]
+
+ outputs[i] = pack([data[js, ...] for js if partitions[js] == i])
+
+`data.shape` must start with `partitions.shape`.
+
+For example:
+
+ # Scalar partitions
+ partitions = 1
+ num_partitions = 2
+ data = [10, 20]
+ outputs[0] = [] # Empty with shape [0, 2]
+ outputs[1] = [[10, 20]]
+
+ # Vector partitions
+ partitions = [0, 0, 1, 1, 0]
+ num_partitions = 2
+ data = [10, 20, 30, 40, 50]
+ outputs[0] = [10, 20, 50]
+ outputs[1] = [30, 40]
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/DynamicPartition.png" alt>
+</div>
+
+partitions: Any shape. Indices in the range `[0, num_partitions)`.
+num_partitions: The number of partitions to output.
+)doc");
+
+REGISTER_OP("DynamicStitch")
+ .Input("indices: N * int32")
+ .Input("data: N * T")
+ .Output("merged: T")
+ .Attr("N : int >= 2")
+ .Attr("T : type")
+ .Doc(R"doc(
+Interleave the values from the `data` tensors into a single tensor.
+
+Builds a merged tensor such that
+
+ merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]
+
+For example, if each `indices[m]` is scalar or vector, we have
+
+ # Scalar indices
+ merged[indices[m], ...] = data[m][...]
+
+ # Vector indices
+ merged[indices[m][i], ...] = data[m][i, ...]
+
+Each `data[i].shape` must start with the corresponding `indices[i].shape`,
+and the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we
+must have `data[i].shape = indices[i].shape + constant`. In terms of this
+`constant`, the output shape is
+
+ merged.shape = [max(indices) + 1] + constant
+
+Values are merged in order, so if an index appears in both `indices[m][i]` and
+`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the
+merged result.
+
+For example:
+
+ indices[0] = 6
+ indices[1] = [4, 1]
+ indices[2] = [[5, 2], [0, 3]]
+ data[0] = [61, 62]
+ data[1] = [[41, 42], [11, 12]]
+ data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]
+ merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],
+ [51, 52], [61, 62]]
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/DynamicStitch.png" alt>
+</div>
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("RandomShuffleQueue")
+ .Output("handle: Ref(string)")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("shapes: list(shape) >= 0 = []")
+ .Attr("capacity: int = -1")
+ .Attr("min_after_dequeue: int = 0")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A queue that randomizes the order of elements.
+
+handle: The handle to the queue.
+component_types: The type of each component in a value.
+shapes: The shape of each component in a value. The length of this attr must
+ be either 0 or the same as the length of component_types. If the length of
+ this attr is 0, the shapes of queue elements are not constrained, and
+ only one element may be dequeued at a time.
+capacity: The upper bound on the number of elements in this queue.
+ Negative numbers mean no limit.
+min_after_dequeue: Dequeue will block unless there would be this
+ many elements after the dequeue or the queue is closed. This
+ ensures a minimum level of mixing of elements.
+seed: If either seed or seed2 is set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, a random seed is used.
+seed2: A second seed to avoid seed collision.
+container: If non-empty, this queue is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this queue will be shared under the given name
+ across multiple sessions.
+)doc");
+
+REGISTER_OP("FIFOQueue")
+ .Output("handle: Ref(string)")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("shapes: list(shape) >= 0 = []")
+ .Attr("capacity: int = -1")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A queue that produces elements in first-in first-out order.
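+
+A typical usage pattern, sketched with the generated Python wrappers (the names
+`tf.FIFOQueue`, `enqueue`, and `dequeue` are assumed here):
+
+```
+q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
+enqueue_op = q.enqueue([some_float_value])  # some_float_value: any float32 tensor
+next_value = q.dequeue()
+```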
+
+handle: The handle to the queue.
+component_types: The type of each component in a value.
+shapes: The shape of each component in a value. The length of this attr must
+ be either 0 or the same as the length of component_types. If the length of
+ this attr is 0, the shapes of queue elements are not constrained, and
+ only one element may be dequeued at a time.
+capacity: The upper bound on the number of elements in this queue.
+ Negative numbers mean no limit.
+container: If non-empty, this queue is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this queue will be shared under the given name
+ across multiple sessions.
+)doc");
+
+REGISTER_OP("QueueEnqueue")
+ .Input("handle: Ref(string)")
+ .Input("components: Tcomponents")
+ .Attr("Tcomponents: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Enqueues a tuple of one or more tensors in the given queue.
+
+The components input has k elements, which correspond to the components of
+tuples stored in the given queue.
+
+N.B. If the queue is full, this operation will block until the given
+element has been enqueued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+components: One or more tensors from which the enqueued tensors should be taken.
+timeout_ms: If the queue is full, this operation will block for up to
+ timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueEnqueueMany")
+ .Input("handle: Ref(string)")
+ .Input("components: Tcomponents")
+ .Attr("Tcomponents: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Enqueues zero or more tuples of one or more tensors in the given queue.
+
+This operation slices each component tensor along the 0th dimension to
+make multiple queue elements. All of the tuple components must have the
+same size in the 0th dimension.
+
+The components input has k elements, which correspond to the components of
+tuples stored in the given queue.
+
+N.B. If the queue is full, this operation will block until the given
+elements have been enqueued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+components: One or more tensors from which the enqueued tensors should
+ be taken.
+timeout_ms: If the queue is too full, this operation will block for up
+ to timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueDequeue")
+ .Input("handle: Ref(string)")
+ .Output("components: component_types")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Dequeues a tuple of one or more tensors from the given queue.
+
+This operation has k outputs, where k is the number of components
+in the tuples stored in the given queue, and output i is the ith
+component of the dequeued tuple.
+
+N.B. If the queue is empty, this operation will block until an element
+has been dequeued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+components: One or more tensors that were dequeued as a tuple.
+component_types: The type of each component in a tuple.
+timeout_ms: If the queue is empty, this operation will block for up to
+ timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueDequeueMany")
+ .Input("handle: Ref(string)")
+ .Input("n: int32")
+ .Output("components: component_types")
+ .Attr("component_types: list(type) >= 1")
+ .Attr("timeout_ms: int = -1")
+ .Doc(R"doc(
+Dequeues n tuples of one or more tensors from the given queue.
+
+This operation concatenates queue-element component tensors along the
+0th dimension to make a single component tensor. All of the components
+in the dequeued tuple will have size n in the 0th dimension.
+
+This operation has k outputs, where k is the number of components in
+the tuples stored in the given queue, and output i is the ith
+component of the dequeued tuple.
+
+N.B. If the queue is empty, this operation will block until n elements
+have been dequeued (or 'timeout_ms' elapses, if specified).
+
+handle: The handle to a queue.
+n: The number of tuples to dequeue.
+components: One or more tensors that were dequeued as a tuple.
+component_types: The type of each component in a tuple.
+timeout_ms: If the queue has fewer than n elements, this operation
+ will block for up to timeout_ms milliseconds.
+ Note: This option is not supported yet.
+)doc");
+
+REGISTER_OP("QueueClose")
+ .Input("handle: Ref(string)")
+ .Attr("cancel_pending_enqueues: bool = false")
+ .Doc(R"doc(
+Closes the given queue.
+
+This operation signals that no more elements will be enqueued in the
+given queue. Subsequent Enqueue(Many) operations will fail.
+Subsequent Dequeue(Many) operations will continue to succeed if
+sufficient elements remain in the queue. Subsequent Dequeue(Many)
+operations that would block will fail immediately.
+
+handle: The handle to a queue.
+cancel_pending_enqueues: If true, all pending enqueue requests that are
+ blocked on the given queue will be cancelled.
+)doc");
+
+REGISTER_OP("QueueSize")
+ .Input("handle: Ref(string)")
+ .Output("size: int32")
+ .Doc(R"doc(
+Computes the number of elements in the given queue.
+
+handle: The handle to a queue.
+size: The number of elements in the given queue.
+)doc");
+
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("LookupTableFind")
+ .Input("table_handle: Ref(string)")
+ .Input("input_values: Tin")
+ .Input("default_value: Tout")
+ .Output("output_values: Tout")
+ .Attr("Tin: type")
+ .Attr("Tout: type")
+ .Doc(R"doc(
+Maps elements of a tensor into associated values given a lookup table.
+
+If an element of the input_values is not present in the table, the
+specified default_value is used.
+
+The table must already be initialized, and the input and output types must
+correspond to the table's key and value types.
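+
+For example (values shown schematically):
+
+    # the table maps {"a": 1, "b": 2} and default_value is -1
+    # input_values is ["a", "c"]
+    # output_values ==> [1, -1]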
+
+table_handle: A handle for a lookup table.
+input_values: A vector of key values.
+default_value: A scalar to return if the input is not found in the table.
+output_values: A vector of values associated to the inputs.
+)doc");
+
+REGISTER_OP("LookupTableSize")
+ .Input("table_handle: Ref(string)")
+ .Output("size: int64")
+ .Doc(R"doc(
+Computes the number of elements in the given table.
+
+table_handle: The handle to a lookup table.
+size: The number of elements in the given table.
+)doc");
+
+REGISTER_OP("HashTable")
+ .Output("table_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("key_dtype: type")
+ .Attr("value_dtype: type")
+ .Doc(R"doc(
+Creates and holds an immutable hash table.
+
+The key and value types can be specified. After initialization, the table
+becomes immutable.
+
+table_handle: A handle to the lookup table.
+container: If non-empty, this hash table is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this hash table is shared under the given name across
+ multiple sessions.
+key_dtype: the type of the table key.
+value_dtype: the type of the table value.
+)doc");
+
+REGISTER_OP("InitializeTable")
+ .Input("table_handle: Ref(string)")
+ .Input("keys: Tkey")
+ .Input("values: Tval")
+ .Attr("Tkey: type")
+ .Attr("Tval: type")
+ .Doc(R"doc(
+Table initializer that takes two tensors for keys and values respectively.
+
+table_handle: a handle of the lookup table to be initialized.
+keys: a vector of keys of type Tkey.
+values: a vector of values of type Tval.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
new file mode 100644
index 0000000000..88af081893
--- /dev/null
+++ b/tensorflow/core/ops/image_ops.cc
@@ -0,0 +1,273 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeArea")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: float")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using area interpolation.
+
+Input images can be of different types but output images are always float.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeBicubic")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: float")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using bicubic interpolation.
+
+Input images can be of different types but output images are always float.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeBilinear")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: float")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using bilinear interpolation.
+
+Input images can be of different types but output images are always float.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("ResizeNearestNeighbor")
+ .Input("images: T")
+ .Input("size: int32")
+ .Output("resized_images: T")
+ .Attr("T: {uint8, int8, int32, float, double}")
+ .Doc(R"doc(
+Resize `images` to `size` using nearest neighbor interpolation.
+
+The input and output images have the same type.
+
+images: 4-D with shape `[batch, height, width, channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+resized_images: 4-D with shape
+ `[batch, new_height, new_width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("RandomCrop")
+ .Input("image: T")
+ .Input("size: int64")
+ .Output("output: T")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .SetIsStateful()
+ .Doc(R"doc(
+Randomly crop `image`.
+
+`size` is a 1-D int64 tensor with 2 elements representing the crop height and
+width. The values must be non-negative.
+
+This Op picks a random location in `image` and crops a `height` by `width`
+rectangle from that location. The random location is picked so the cropped
+area will fit inside the original image.
+
+image: 3-D of shape `[height, width, channels]`.
+size: 1-D of length 2 containing: `crop_height`, `crop_width`.
+seed: If either seed or seed2 are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+output: 3-D of shape `[crop_height, crop_width, channels]`.
+)doc");
+// TODO(shlens): Support variable rank in RandomCrop.
+
+// --------------------------------------------------------------------------
+REGISTER_OP("DecodeJpeg")
+ .Input("contents: string")
+ .Attr("channels: int = 0")
+ .Attr("ratio: int = 1")
+ .Attr("fancy_upscaling: bool = true")
+ .Attr("try_recover_truncated: bool = false")
+ .Attr("acceptable_fraction: float = 1.0")
+ .Output("image: uint8")
+ .Doc(R"doc(
+Decode a JPEG-encoded image to a uint8 tensor.
+
+The attr `channels` indicates the desired number of color channels for the
+decoded image.
+
+Accepted values are:
+
+* 0: Use the number of channels in the JPEG-encoded image.
+* 1: output a grayscale image.
+* 3: output an RGB image.
+
+If needed, the JPEG-encoded image is transformed to match the requested number
+of color channels.
+
+The attr `ratio` allows downscaling the image by an integer factor during
+decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
+downscaling the image later.
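+
+For example, a typical decoding pipeline can be sketched with the generated
+Python wrappers (the names `tf.read_file` and `tf.image.decode_jpeg` are
+assumed here):
+
+```
+contents = tf.read_file("/path/to/image.jpg")
+image = tf.image.decode_jpeg(contents, channels=3, ratio=2)
+```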
+
+contents: 0-D. The JPEG-encoded image.
+channels: Number of color channels for the decoded image.
+ratio: Downscaling ratio.
+fancy_upscaling: If true use a slower but nicer upscaling of the
+ chroma planes (yuv420/422 only).
+try_recover_truncated: If true try to recover an image from truncated input.
+acceptable_fraction: The minimum required fraction of lines before a truncated
+ input is accepted.
+image: 3-D with shape `[height, width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("EncodeJpeg")
+ .Input("image: uint8")
+ .Attr("format: {'', 'grayscale', 'rgb'} = ''")
+ .Attr("quality: int = 95")
+ .Attr("progressive: bool = false")
+ .Attr("optimize_size: bool = false")
+ .Attr("chroma_downsampling: bool = true")
+ .Attr("density_unit: {'in', 'cm'} = 'in'")
+ .Attr("x_density: int = 300")
+ .Attr("y_density: int = 300")
+ .Attr("xmp_metadata: string = ''")
+ .Output("contents: string")
+ .Doc(R"doc(
+JPEG-encode an image.
+
+`image` is a 3-D uint8 Tensor of shape `[height, width, channels]`.
+
+The attr `format` can be used to override the color format of the encoded
+output. Values can be:
+
+* `''`: Use a default format based on the number of channels in the image.
+* `grayscale`: Output a grayscale JPEG image. The `channels` dimension
+ of `image` must be 1.
+* `rgb`: Output an RGB JPEG image. The `channels` dimension
+ of `image` must be 3.
+
+If `format` is not specified or is the empty string, a default format is picked
+based on the number of channels in `image`:
+
+* 1: Output a grayscale image.
+* 3: Output an RGB image.
+
+image: 3-D with shape `[height, width, channels]`.
+format: Per pixel image format.
+quality: Quality of the compression from 0 to 100 (higher is better and slower).
+progressive: If True, create a JPEG that loads progressively (coarse to fine).
+optimize_size: If True, spend CPU/RAM to reduce size with no quality change.
+chroma_downsampling: See http://en.wikipedia.org/wiki/Chroma_subsampling.
+density_unit: Unit used to specify `x_density` and `y_density`:
+ pixels per inch (`'in'`) or centimeter (`'cm'`).
+x_density: Horizontal pixels per density unit.
+y_density: Vertical pixels per density unit.
+xmp_metadata: If not empty, embed this XMP metadata in the image header.
+contents: 0-D. JPEG-encoded image.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("AdjustContrast")
+ .Input("images: T")
+ .Input("contrast_factor: float")
+ .Input("min_value: float")
+ .Input("max_value: float")
+ .Output("output: float")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
+ .Doc(R"Doc(
+Adjust the contrast of one or more images.
+
+`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
+interpreted as `[height, width, channels]`. The other dimensions only
+represent a collection of images, such as `[batch, height, width, channels].`
+
+Contrast is adjusted independently for each channel of each image.
+
+For each channel, the Op first computes the mean of the image pixels in the
+channel and then adjusts each component of each pixel to
+`(x - mean) * contrast_factor + mean`.
+
+These adjusted values are then clipped to fit in the `[min_value, max_value]`
+interval.
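+
+For example, a single-channel image with pixel values `[1., 3.]` has channel
+mean `2.0`; with `contrast_factor = 2.0` the adjusted values are `[0., 4.]`
+before clipping to `[min_value, max_value]`.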
+
+images: Images to adjust. At least 3-D.
+contrast_factor: A float multiplier for adjusting contrast.
+min_value: Minimum value for clipping the adjusted pixels.
+max_value: Maximum value for clipping the adjusted pixels.
+output: The contrast-adjusted image or images.
+)Doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("DecodePng")
+ .Input("contents: string")
+ .Attr("channels: int = 0")
+ .Output("image: uint8")
+ .Doc(R"doc(
+Decode a PNG-encoded image to a uint8 tensor.
+
+The attr `channels` indicates the desired number of color channels for the
+decoded image.
+
+Accepted values are:
+
+* 0: Use the number of channels in the PNG-encoded image.
+* 1: output a grayscale image.
+* 3: output an RGB image.
+* 4: output an RGBA image.
+
+If needed, the PNG-encoded image is transformed to match the requested number
+of color channels.
+
+contents: 0-D. The PNG-encoded image.
+channels: Number of color channels for the decoded image.
+image: 3-D with shape `[height, width, channels]`.
+)doc");
+
+// --------------------------------------------------------------------------
+REGISTER_OP("EncodePng")
+ .Input("image: uint8")
+ .Attr("compression: int = -1")
+ .Output("contents: string")
+ .Doc(R"doc(
+PNG-encode an image.
+
+`image` is a 3-D uint8 Tensor of shape `[height, width, channels]` where
+`channels` is:
+
+* 1: for grayscale.
+* 3: for RGB.
+* 4: for RGBA.
+
+The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
+default or a value from 0 to 9. 9 is the highest compression level, generating
+the smallest output, but is slower.
+
+image: 3-D with shape `[height, width, channels]`.
+compression: Compression level.
+contents: 0-D. PNG-encoded image.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
new file mode 100644
index 0000000000..937fedd45d
--- /dev/null
+++ b/tensorflow/core/ops/io_ops.cc
@@ -0,0 +1,332 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Save")
+ .Input("filename: string")
+ .Input("tensor_names: string")
+ .Input("data: T")
+ .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})")
+ .Doc(R"doc(
+Saves the input tensors to disk.
+
+The size of `tensor_names` must match the number of tensors in `data`. `data[i]`
+is written to `filename` with name `tensor_names[i]`.
+
+See also `SaveSlices`.
+
+filename: Must have a single element. The name of the file to which we write
+the tensor.
+tensor_names: Shape `[N]`. The names of the tensors to be saved.
+data: `N` tensors to save.
+)doc");
+
+REGISTER_OP("SaveSlices")
+ .Input("filename: string")
+ .Input("tensor_names: string")
+ .Input("shapes_and_slices: string")
+ .Input("data: T")
+ .Attr("T: list({float, double, int32, int64, quint8, qint8, qint32})")
+ .Doc(R"doc(
+Saves input tensor slices to disk.
+
+This is like `Save` except that tensors can be listed in the saved file as being
+a slice of a larger tensor. `shapes_and_slices` specifies the shape of the
+larger tensor and the slice that this tensor covers. `shapes_and_slices` must
+have as many elements as `tensor_names`.
+
+Elements of the `shapes_and_slices` input must either be:
+
+* The empty string, in which case the corresponding tensor is
+ saved normally.
+* A string of the form `dim0 dim1 ... dimN-1 slice-spec` where the
+ `dimI` are the dimensions of the larger tensor and `slice-spec`
+ specifies what part is covered by the tensor to save.
+
+`slice-spec` itself is a `:`-separated list: `slice0:slice1:...:sliceN-1`
+where each `sliceI` is either:
+
+* The string `-` meaning that the slice covers all indices of this dimension
+* `start,length` where `start` and `length` are integers. In that
+ case the slice covers `length` indices starting at `start`.
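+
+For example, the following `shapes_and_slices` entry describes a slice of a
+`[4, 10]` tensor that covers rows 0 and 1 and all columns:
+
+    4 10 0,2:-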
+
+See also `Save`.
+
+filename: Must have a single element. The name of the file to which we write the
+tensor.
+tensor_names: Shape `[N]`. The names of the tensors to be saved.
+shapes_and_slices: Shape `[N]`. The shapes and slice specifications to use when
+saving the tensors.
+data: `N` tensors to save.
+)doc");
+
+REGISTER_OP("Restore")
+ .Input("file_pattern: string")
+ .Input("tensor_name: string")
+ .Output("tensor: dt")
+ .Attr("dt: type")
+ .Attr("preferred_shard: int = -1")
+ .Doc(R"doc(
+Restores a tensor from checkpoint files.
+
+Reads a tensor stored in one or several files. If there are several files (for
+instance because a tensor was saved as slices), `file_pattern` may contain
+wildcard symbols (`*` and `?`) in the filename portion only, not in the
+directory portion.
+
+If a `file_pattern` matches several files, `preferred_shard` can be used to hint
+in which file the requested tensor is likely to be found. This op will first
+open the file at index `preferred_shard` in the list of matching files and try
+to restore tensors from that file. Only if some tensors or tensor slices are
+not found in that first file, then the Op opens all the files. Setting
+`preferred_shard` to match the value passed as the `shard` input
+of a matching `Save` Op may speed up Restore. This attribute only affects
+performance, not correctness. The default value -1 means files are processed in
+order.
+
+See also `RestoreSlice`.
+
+file_pattern: Must have a single element. The pattern of the files from
+ which we read the tensor.
+tensor_name: Must have a single element. The name of the tensor to be
+ restored.
+tensor: The restored tensor.
+dt: The type of the tensor to be restored.
+preferred_shard: Index of file to open first if multiple files match
+ `file_pattern`.
+)doc");
+
+REGISTER_OP("RestoreSlice")
+ .Input("file_pattern: string")
+ .Input("tensor_name: string")
+ .Input("shape_and_slice: string")
+ .Output("tensor: dt")
+ .Attr("dt: type")
+ .Attr("preferred_shard: int = -1")
+ .Doc(R"doc(
+Restores a tensor from checkpoint files.
+
+This is like `Restore` except that the restored tensor can be listed as filling
+only a slice of a larger tensor. `shape_and_slice` specifies the shape of the
+larger tensor and the slice that the restored tensor covers.
+
+The `shape_and_slice` input has the same format as the
+elements of the `shapes_and_slices` input of the `SaveSlices` op.
+
+file_pattern: Must have a single element. The pattern of the files from
+ which we read the tensor.
+tensor_name: Must have a single element. The name of the tensor to be
+ restored.
+shape_and_slice: Scalar. The shape and slice specification to use when
+ restoring the tensor.
+tensor: The restored tensor.
+dt: The type of the tensor to be restored.
+preferred_shard: Index of file to open first if multiple files match
+ `file_pattern`. See the documentation for `Restore`.
+)doc");
+
+REGISTER_OP("ShardedFilename")
+ .Input("basename: string")
+ .Input("shard: int32")
+ .Input("num_shards: int32")
+ .Output("filename: string")
+ .Doc(R"doc(
+Generate a sharded filename. The filename is printf formatted as
+ `%s-%05d-of-%05d` using (`basename`, `shard`, `num_shards`).
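+
+For example, basename "mydata", shard 2 and num_shards 10 produce the filename
+"mydata-00002-of-00010".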
+)doc");
+
+REGISTER_OP("ShardedFilespec")
+ .Input("basename: string")
+ .Input("num_shards: int32")
+ .Output("filename: string")
+ .Doc(R"doc(
+Generate a glob pattern matching all sharded file names.
+)doc");
+
+// Reader source ops ----------------------------------------------------------
+
+REGISTER_OP("WholeFileReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the entire contents of a file as a value.
+
+To use, enqueue filenames in a Queue. The output of ReaderRead will
+be a filename (key) and the contents of that file (value).
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TextLineReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("skip_header_lines: int = 0")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the lines of a file delimited by '\n'.
+
+reader_handle: The handle to reference the Reader.
+skip_header_lines: Number of lines to skip from the beginning of every file.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("FixedLengthRecordReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("header_bytes: int = 0")
+ .Attr("record_bytes: int")
+ .Attr("footer_bytes: int = 0")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs fixed-length records from a file.
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TFRecordReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the records from a TensorFlow Records file.
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("IdentityReader")
+ .Output("reader_handle: Ref(string)")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+A Reader that outputs the queued work as both the key and value.
+
+To use, enqueue strings in a Queue. ReaderRead will take the front
+work string and output (work, work).
+
+reader_handle: The handle to reference the Reader.
+container: If non-empty, this reader is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this reader is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+// Ops that operate on Readers ------------------------------------------------
+
+REGISTER_OP("ReaderRead")
+ .Input("reader_handle: Ref(string)")
+ .Input("queue_handle: Ref(string)")
+ .Output("key: string")
+ .Output("value: string")
+ .Doc(R"doc(
+Returns the next record (key, value pair) produced by a Reader.
+
+Will dequeue from the input queue if necessary (e.g. when the
+Reader needs to start reading from a new file since it has finished
+with the previous file).
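+
+For example (schematically, with a WholeFileReader):
+
+    # queue_handle refers to a queue holding filenames
+    key, value = ReaderRead(reader_handle, queue_handle)
+    # key ==> a filename taken from the queue
+    # value ==> the entire contents of that file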
+
+reader_handle: Handle to a Reader.
+queue_handle: Handle to a Queue, with string work items.
+key: A scalar.
+value: A scalar.
+)doc");
+
+REGISTER_OP("ReaderNumRecordsProduced")
+ .Input("reader_handle: Ref(string)")
+ .Output("records_produced: int64")
+ .Doc(R"doc(
+Returns the number of records this Reader has produced.
+
+This is the same as the number of ReaderRead executions that have
+succeeded.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+REGISTER_OP("ReaderNumWorkUnitsCompleted")
+ .Input("reader_handle: Ref(string)")
+ .Output("units_completed: int64")
+ .Doc(R"doc(
+Returns the number of work units this Reader has finished processing.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+REGISTER_OP("ReaderSerializeState")
+ .Input("reader_handle: Ref(string)")
+ .Output("state: string")
+ .Doc(R"doc(
+Produce a string tensor that encodes the state of a Reader.
+
+Not all Readers support being serialized, so this can produce an
+Unimplemented error.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+REGISTER_OP("ReaderRestoreState")
+ .Input("reader_handle: Ref(string)")
+ .Input("state: string")
+ .Doc(R"doc(
+Restore a reader to a previously saved state.
+
+Not all Readers support being restored, so this can produce an
+Unimplemented error.
+
+reader_handle: Handle to a Reader.
+state: Result of a ReaderSerializeState of a Reader with type
+ matching reader_handle.
+)doc");
+
+REGISTER_OP("ReaderReset")
+ .Input("reader_handle: Ref(string)")
+ .Doc(R"doc(
+Restore a Reader to its initial clean state.
+
+reader_handle: Handle to a Reader.
+)doc");
+
+// Other input Ops ----------------------------------------------------------
+
+REGISTER_OP("ReadFile")
+ .Input("filename: string")
+ .Output("contents: string")
+ .Doc(R"doc(
+Reads and outputs the entire contents of the input filename.
+)doc");
+
+REGISTER_OP("MatchingFiles")
+ .Input("pattern: string")
+ .Output("filenames: string")
+ .Doc(R"doc(
+Returns the set of files matching a pattern.
+
+Note that this routine only supports wildcard characters in the
+basename portion of the pattern, not in the directory portion.
+
+pattern: A (scalar) shell wildcard pattern.
+filenames: A vector of matching filenames.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
new file mode 100644
index 0000000000..a9b940295e
--- /dev/null
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -0,0 +1,97 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("MatrixDeterminant")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the determinant of a square matrix.
+
+input: A tensor of shape `[M, M]`.
+output: A scalar, equal to the determinant of the input.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("BatchMatrixDeterminant")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the determinants for a batch of square matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor of shape `[...]` containing the
+determinants for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[...]`.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("MatrixInverse")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the inverse of a square invertible matrix. Checks for invertibility.
+
+input: Shape is `[M, M]`.
+output: Shape is `[M, M]` containing the matrix inverse of the input.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("BatchMatrixInverse")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Calculates the inverse of square invertible matrices. Checks for invertibility.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor of the same shape as the input
+containing the inverse for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("Cholesky")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {double, float}")
+ .Doc(R"doc(
+Calculates the Cholesky decomposition of a square matrix.
+
+The input has to be symmetric and positive definite. Only the lower-triangular
+part of the input will be used for this operation. The upper-triangular part
+will not be read.
+
+The result is the lower-triangular matrix of the Cholesky decomposition of the
+input.
+
+input: Shape is `[M, M]`.
+output: Shape is `[M, M]`.
+T: The type of values in the input and output.
+)doc");
+
+REGISTER_OP("BatchCholesky")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {double, float}")
+ .Doc(R"doc(
+Calculates the Cholesky decomposition of a batch of square matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices, with the same constraints as the single matrix Cholesky
+decomposition above. The output is a tensor of the same shape as the input
+containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+T: The type of values in the input and output.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
new file mode 100644
index 0000000000..28546fe645
--- /dev/null
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -0,0 +1,43 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Assert")
+ .Input("condition: bool")
+ .Input("data: T")
+ .Attr("T: list(type)")
+ .Attr("summarize: int = 3")
+ .Doc(R"doc(
+Asserts that the given condition is true.
+
+If `condition` evaluates to false, print the list of tensors in `data`.
+`summarize` determines how many entries of the tensors to print.
+
+condition: The condition to evaluate.
+data: The tensors to print out when condition is false.
+summarize: Print this many entries of each tensor.
+)doc");
+
+REGISTER_OP("Print")
+ .Input("input: T")
+ .Input("data: U")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("U: list(type)")
+ .Attr("message: string = ''")
+ .Attr("first_n: int = -1")
+ .Attr("summarize: int = 3")
+ .Doc(R"doc(
+Prints a list of tensors.
+
+Passes `input` through to `output` and prints `data` when evaluating.
+
+input: The tensor passed to `output`.
+data: A list of tensors to print out when op is evaluated.
+output: The unmodified `input` tensor.
+message: A string, prefix of the error message.
+first_n: Only log `first_n` number of times. -1 disables logging.
+summarize: Only print this many entries of each tensor.
+)doc");
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
new file mode 100644
index 0000000000..20e56316ea
--- /dev/null
+++ b/tensorflow/core/ops/math_ops.cc
@@ -0,0 +1,1053 @@
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("AddN")
+ .Input("inputs: N * T")
+ .Output("sum: T")
+ .Attr("N: int >= 1")
+ .Attr("T: numbertype")
+ .SetIsCommutative()
+ .SetIsAggregate()
+ .Doc(R"doc(
+Add all input tensors element-wise.
+
+inputs: Must all be the same size and shape.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BatchMatMul")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("out: T")
+ .Attr("T: {float, double, int32, complex64}")
+ .Attr("adj_x: bool = false")
+ .Attr("adj_y: bool = false")
+ .Doc(R"doc(
+Multiplies slices of two tensors in batches.
+
+Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+viewed as an element of a batch), and arranges the individual results
+in a single output tensor of the same batch size. Each of the
+individual slices can optionally be adjointed (to adjoint a matrix
+means to transpose and conjugate it) before multiplication by setting
+the `adj_x` or `adj_y` flag to `True`; both default to `False`.
+
+The input tensors `x` and `y` are 3-D or higher with shape `[..., r_x, c_x]`
+and `[..., r_y, c_y]`.
+
+The output tensor is 3-D or higher with shape `[..., r_o, c_o]`, where:
+
+ r_o = c_x if adj_x else r_x
+ c_o = r_y if adj_y else c_y
+
+It is computed as:
+
+ out[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
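+
+For example (shapes only):
+
+    # x has shape [2, 3, 4], y has shape [2, 4, 5], adj_x = adj_y = false
+    # out has shape [2, 3, 5]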
+
+x: 3-D or higher with shape `[..., r_x, c_x]`.
+y: 3-D or higher with shape `[..., r_y, c_y]`.
+out: 3-D or higher with shape `[..., r_o, c_o]`
+adj_x: If `True`, adjoint the slices of `x`. Defaults to `False`.
+adj_y: If `True`, adjoint the slices of `y`. Defaults to `False`.
+)doc");
+
+// --------------------------------------------------------------------------
+// Casting Ops
+//
+// NOTE: Only a smaller number of types are supported by
+// Cast. The exact casting rule is TBD. The current
+// implementation uses C++ static cast rules for numeric
+// types, which may be changed in the future.
+REGISTER_OP("Cast")
+ .Input("x: SrcT")
+ .Output("y: DstT")
+ .Attr("SrcT: type")
+ .Attr("DstT: type")
+ .Doc(R"doc(
+Cast x of type SrcT to y of DstT.
+)doc");
+
+REGISTER_OP("_HostCast")
+ .Input("x: SrcT")
+ .Output("y: DstT")
+ .Attr("SrcT: type")
+ .Attr("DstT: type")
+ .Doc(R"doc(
+Cast x of type SrcT to y of DstT.
+
+_HostCast requires its input and produces its output in host memory.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Abs")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Computes the absolute value of a tensor.
+
+Given a tensor `x`, this operation returns a tensor containing the absolute
+value of each element in `x`. For example, if x is an input element and y is
+an output element, this operation computes \\(y = |x|\\).
+)doc");
+
+REGISTER_OP("ComplexAbs")
+ .Input("x: complex64")
+ .Output("y: float")
+ .Doc(R"doc(
+Computes the complex absolute value of a tensor.
+
+Given a tensor `x` of complex numbers, this operation returns a tensor of type
+`float` that is the absolute value of each element in `x`. All elements in `x`
+must be complex numbers of the form \\(a + bj\\). The absolute value is
+computed as \\( \sqrt{a^2 + b^2}\\).
+
+For example:
+
+```
+# tensor 'x' is [[-2.25 + 4.75j], [-3.25 + 5.75j]]
+tf.complex_abs(x) ==> [5.25594902, 6.60492229]
+```
+)doc");
+
+// Declares cwise unary operations signature: 't -> 't
+#define UNARY() \
+ Input("x: T").Output("y: T").Attr( \
+ "T: {float, double, int32, complex64, int64}")
+
+REGISTER_OP("Neg")
+ .UNARY()
+ .Doc(R"doc(
+Computes numerical negative value element-wise.
+I.e., \\(y = -x\\).
+)doc");
+
+REGISTER_OP("Inv")
+ .UNARY()
+ .Doc(R"doc(
+Computes the reciprocal of x element-wise.
+I.e., \\(y = 1 / x\\).
+)doc");
+
+REGISTER_OP("Square")
+ .UNARY()
+ .Doc(R"doc(
+Computes square of x element-wise.
+I.e., \\(y = x * x = x^2\\).
+)doc");
+
+REGISTER_OP("Sqrt")
+ .UNARY()
+ .Doc(R"doc(
+Computes square root of x element-wise.
+I.e., \\(y = \sqrt{x} = x^{1/2}\\).
+)doc");
+
+REGISTER_OP("Rsqrt")
+ .UNARY()
+ .Doc(R"doc(
+Computes reciprocal of square root of x element-wise.
+I.e., \\(y = 1 / \sqrt{x}\\).
+)doc");
+
+REGISTER_OP("Exp")
+ .UNARY()
+ .Doc(R"doc(
+Computes exponential of x element-wise. \\(y = e^x\\).
+)doc");
+
+REGISTER_OP("Log")
+ .UNARY()
+ .Doc(R"doc(
+Computes natural logarithm of x element-wise.
+I.e., \\(y = \log_e x\\).
+)doc");
+
+REGISTER_OP("Tanh")
+ .UNARY()
+ .Doc(R"doc(
+Computes hyperbolic tangent of `x` element-wise.
+)doc");
+
+REGISTER_OP("Sigmoid")
+ .UNARY()
+ .Doc(R"doc(
+Computes sigmoid of `x` element-wise.
+
+Specifically, `y = 1 / (1 + exp(-x))`.
+)doc");
+
+REGISTER_OP("Sin")
+ .UNARY()
+ .Doc(R"doc(
+Computes sin of x element-wise.
+)doc");
+
+REGISTER_OP("Cos")
+ .UNARY()
+ .Doc(R"doc(
+Computes cos of x element-wise.
+)doc");
+
+#undef UNARY
+
+REGISTER_OP("IsNan")
+ .Input("x: T")
+ .Output("y: bool")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns which elements of x are NaN.
+)doc");
+
+REGISTER_OP("IsInf")
+ .Input("x: T")
+ .Output("y: bool")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns which elements of x are Inf.
+)doc");
+
+REGISTER_OP("IsFinite")
+ .Input("x: T")
+ .Output("y: bool")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns which elements of x are finite.
+)doc");
+
+REGISTER_OP("Sign")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double, int32, int64}")
+ .Doc(R"doc(
+Returns an element-wise indication of the sign of a number.
+
+y = sign(x) = -1 if x < 0; 0 if x == 0; 1 if x > 0.
+)doc");
+
+REGISTER_OP("Floor")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns element-wise largest integer not greater than x.
+)doc");
+
+REGISTER_OP("Ceil")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Returns element-wise smallest integer not less than x.
+)doc");
+
+// Declares cwise binary operations signature: 't, 't -> 't.
+
+#define BINARY_MORE() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {float, double, int8, int16, int32, complex64, int64}")
+
+#define BINARY_FEWER() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {float, double, int32, complex64, int64}")
+
+REGISTER_OP("Add")
+ .BINARY_MORE()
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns x + y element-wise.
+
+*NOTE*: Add supports broadcasting. AddN does not.
+)doc");
+
+REGISTER_OP("Sub")
+ .BINARY_FEWER()
+ .Doc(R"doc(
+Returns x - y element-wise.
+)doc");
+
+REGISTER_OP("Mul")
+ .BINARY_MORE()
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns x * y element-wise.
+)doc");
+
+REGISTER_OP("Div")
+ .BINARY_FEWER()
+ .Doc(R"doc(
+Returns x / y element-wise.
+)doc");
+
+#undef BINARY_FEWER
+#undef BINARY_MORE
+
+REGISTER_OP("Maximum")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double, int32, int64}")
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts.
+)doc");
+
+REGISTER_OP("Minimum")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double, int32, int64}")
+ .SetIsCommutative()
+ .Doc(R"doc(
+Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts.
+)doc");
+
+REGISTER_OP("Mod")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {int32, int64, float, double}")
+ .Doc(R"doc(
+Returns element-wise remainder of division.
+)doc");
+
+REGISTER_OP("Pow")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double, int32, complex64, int64}")
+ .Doc(R"doc(
+Computes the power of one value to another.
+
+Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+corresponding elements in `x` and `y`. For example:
+
+```
+# tensor 'x' is [[2, 2], [3, 3]]
+# tensor 'y' is [[8, 16], [2, 3]]
+tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+```
+)doc");
+
+// --------------------------------------------------------------------------
+
+// Declares cwise binary comparison operations signature: 't, 't -> bool,
+// where 't has a natural total order.
+#define COMPARISON() \
+ Input("x: T").Input("y: T").Output("z: bool").Attr( \
+ "T: {float, double, int32, int64}")
+
+REGISTER_OP("Less")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x < y) element-wise.
+)doc");
+
+REGISTER_OP("LessEqual")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x <= y) element-wise.
+)doc");
+
+REGISTER_OP("Greater")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x > y) element-wise.
+)doc");
+
+REGISTER_OP("GreaterEqual")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x >= y) element-wise.
+)doc");
+
+#undef COMPARISON
+
+// --------------------------------------------------------------------------
+
+#define COMPARISON() \
+ Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \
+ "T: {float, double, int32, int64, complex64, quint8, qint8, qint32}")
+
+REGISTER_OP("Equal")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x == y) element-wise.
+)doc");
+
+REGISTER_OP("NotEqual")
+ .COMPARISON()
+ .Doc(R"doc(
+Returns the truth value of (x != y) element-wise.
+)doc");
+
+#undef COMPARISON
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("LogicalNot")
+ .Input("x: bool")
+ .Output("y: bool")
+ .Doc(R"doc(
+Returns the truth value of NOT x element-wise.
+)doc");
+
+#define BINARY_LOGICAL() \
+ Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative()
+
+REGISTER_OP("LogicalAnd")
+ .BINARY_LOGICAL()
+ .Doc(R"doc(
+Returns the truth value of x AND y element-wise.
+)doc");
+
+REGISTER_OP("LogicalOr")
+ .BINARY_LOGICAL()
+ .Doc(R"doc(
+Returns the truth value of x OR y element-wise.
+)doc");
+
+#undef BINARY_LOGICAL
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Select")
+ .Input("condition: bool")
+ .Input("t: T")
+ .Input("e: T")
+ .Output("out: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Selects elements from `t` or `e`, depending on `condition`.
+
+The `condition`, `t`, and `e` tensors must all have the same shape,
+and the output will also have that shape. The `condition` tensor acts
+as an element-wise mask that chooses, based on the value at each
+element, whether the corresponding element in the output should be
+taken from `t` (if true) or `e` (if false).
+
+For example:
+
+```prettyprint
+# 'condition' tensor is [[True, False]
+# [True, False]]
+# 't' is [[1, 1],
+# [1, 1]]
+# 'e' is [[2, 2],
+# [2, 2]]
+select(condition, t, e) ==> [[1, 2],
+ [1, 2]]
+```
+
+t:= A `Tensor` with the same shape as `condition`.
+e:= A `Tensor` with the same type and shape as `t`.
+out:= A `Tensor` with the same type and shape as `t` and `e`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("MatMul")
+ .Input("a: T")
+ .Input("b: T")
+ .Output("product: T")
+ .Attr("transpose_a: bool = false")
+ .Attr("transpose_b: bool = false")
+ .Attr("T: {float, double, int32, complex64}")
+ .Doc(R"doc(
+Multiply the matrix "a" by the matrix "b".
+
+The inputs must be two-dimensional matrices and the inner dimension of
+"a" (after being transposed if transpose_a is true) must match the
+outer dimension of "b" (after being transposed if transposed_b is
+true).
+
+*Note*: The default kernel implementation for MatMul on GPUs uses
+cublas.
+
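+For example, a small illustrative case (assuming the Python wrapper
+`tf.matmul` for this op):
+
+```
+# 'a' is [[1, 2, 3], [4, 5, 6]]        (shape [2, 3])
+# 'b' is [[7, 8], [9, 10], [11, 12]]   (shape [3, 2])
+tf.matmul(a, b) ==> [[58, 64], [139, 154]]
+```
+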
+transpose_a: If true, "a" is transposed before multiplication.
+transpose_b: If true, "b" is transposed before multiplication.
+)doc");
+
+REGISTER_OP("SparseMatMul")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("product: float")
+ .Attr("transpose_a: bool = false")
+ .Attr("transpose_b: bool = false")
+ .Attr("a_is_sparse: bool = false")
+ .Attr("b_is_sparse: bool = false")
+ .Doc(R"doc(
+Multiply matrix "a" by matrix "b".
+
+The inputs must be two-dimensional matrices and the inner dimension of "a" must
+match the outer dimension of "b". This op is optimized for the case where at
+least one of "a" or "b" is sparse. The breakeven for using this versus a dense
+matrix multiply on one platform was 30% zero values in the sparse matrix.
+)doc");
+
+// --------------------------------------------------------------------------
+
+// For operations where the output is a reduction function along some
+// dimensions of the input.
+REGISTER_OP("Sum")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the sum of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
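+For example, an illustrative case (assuming a Python wrapper `tf.reduce_sum`
+whose arguments mirror this op's inputs and `keep_dims` attr):
+
+```
+# 'x' is [[1, 1, 1],
+#         [1, 1, 1]]
+tf.reduce_sum(x, [0, 1]) ==> 6
+tf.reduce_sum(x, 0) ==> [2, 2, 2]
+tf.reduce_sum(x, 1) ==> [3, 3]
+tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
+```
+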
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Mean")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the mean of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Prod")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the product of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Min")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the minimum of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Max")
+ .Input("input: T")
+ .Input("reduction_indices: int32")
+ .Output("output: T")
+ .Attr("keep_dims: bool = false")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Computes the maximum of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("ArgMax")
+ .Input("input: T")
+ .Input("dimension: int32")
+ .Output("output: int64")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Returns the index with the largest value across dimensions of a tensor.
+
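+For example, an illustrative case (assuming the Python wrapper `tf.argmax`):
+
+```
+# 'input' is [[1, 5, 3],
+#             [4, 2, 6]]
+tf.argmax(input, 0) ==> [1, 0, 1]
+tf.argmax(input, 1) ==> [1, 2]
+```
+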
+dimension: int32, 0 <= dimension < rank(input). Describes which dimension
+ of the input Tensor to reduce across. For vectors, use dimension = 0.
+)doc");
+
+REGISTER_OP("ArgMin")
+ .Input("input: T")
+ .Input("dimension: int32")
+ .Output("output: int64")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+Returns the index with the smallest value across dimensions of a tensor.
+
+dimension: int32, 0 <= dimension < rank(input). Describes which dimension
+ of the input Tensor to reduce across. For vectors, use dimension = 0.
+)doc");
+
+REGISTER_OP("SegmentSum")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the sum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \sum_j data_j\\) where sum is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentSum.png" alt>
+</div>
+
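+For example, a small illustrative case (assuming the Python wrapper
+`tf.segment_sum` for this op):
+
+```prettyprint
+# 'data' is [[1, 2, 3, 4],
+#            [5, 6, 7, 8],
+#            [4, 3, 2, 1]]
+# 'segment_ids' is [0, 0, 1]
+tf.segment_sum(data, segment_ids) ==> [[6, 8, 10, 12],
+                                       [4, 3, 2, 1]]
+```
+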
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentMean")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the mean along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
+over `j` such that `segment_ids[j] == i` and `N` is the total number of
+values summed.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentMean.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentProd")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the product along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \prod_j data_j\\) where the product is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentProd.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentMin")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the minimum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \min_j(data_j)\\) where `min` is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentMin.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SegmentMax")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the maximum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \max_j(data_j)\\) where `max` is over `j` such
+that `segment_ids[j] == i`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/SegmentMax.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("UnsortedSegmentSum")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Input("num_segments: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .Doc(R"doc(
+Computes the sum along segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Computes a tensor such that
+\\(output_i = \sum_j data_j\\) where sum is over `j` such
+that `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`
+need not be sorted and need not cover all values in the full
+range of valid values.
+
+If the sum is empty for a given segment ID `i`, `output[i] = 0`.
+
+`num_segments` should equal the number of distinct segment IDs.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/UnsortedSegmentSum.png" alt>
+</div>
+
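+For example, an illustrative case (assuming a Python wrapper
+`tf.unsorted_segment_sum` for this op):
+
+```prettyprint
+# 'data' is [1, 2, 3, 4]
+# 'segment_ids' is [0, 1, 0, 1]
+# 'num_segments' is 2
+tf.unsorted_segment_sum(data, segment_ids, num_segments) ==> [4, 6]
+```
+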
+segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+first dimension.
+
+output: Has same shape as data, except for dimension_0 which
+has size `num_segments`.
+
+)doc");
+
+REGISTER_OP("SparseSegmentSum")
+ .Input("data: T")
+ .Input("indices: int32")
+ .Input("segment_ids: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes the sum along sparse segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
+dimension, selecting a subset of dimension_0, specified by `indices`.
+
+For example:
+
+```prettyprint
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+
+# Select two rows, one segment.
+tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+ ==> [[0 0 0 0]]
+
+# Select two rows, two segments.
+tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+ ==> [[ 1 2 3 4]
+ [-1 -2 -3 -4]]
+
+# Select all rows, two segments.
+tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+ ==> [[0 0 0 0]
+ [5 6 7 8]]
+
+# Which is equivalent to:
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+```
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+)doc");
+
+REGISTER_OP("SparseSegmentMean")
+ .Input("data: T")
+ .Input("indices: int32")
+ .Input("segment_ids: int32")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes the mean along sparse segments of a tensor.
+
+Read [the section on Segmentation](../python/math_ops.md#segmentation)
+for an explanation of segments.
+
+Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+dimension, selecting a subset of dimension_0, specified by `indices`.
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+output: Has same shape as data, except for dimension_0 which
+has size `k`, the number of segments.
+
+)doc");
+
+REGISTER_OP("SparseSegmentMeanGrad")
+ .Input("grad: T")
+ .Input("indices: int32")
+ .Input("segment_ids: int32")
+ .Input("output_dim0: int32")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes gradients for SparseSegmentMean.
+
+Returns tensor "output" with same shape as grad, except for dimension_0 whose
+value is output_dim0.
+
+grad: gradient propagated to the SparseSegmentMean op.
+indices: indices passed to the corresponding SparseSegmentMean op.
+segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
+output_dim0: dimension_0 of "data" passed to SparseSegmentMean op.
+)doc");
+
+REGISTER_OP("All")
+ .Input("input: bool")
+ .Input("reduction_indices: int32")
+ .Output("output: bool")
+ .Attr("keep_dims: bool = false")
+ .Doc(R"doc(
+Computes the "logical and" of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+REGISTER_OP("Any")
+ .Input("input: bool")
+ .Input("reduction_indices: int32")
+ .Attr("keep_dims: bool = false")
+ .Output("output: bool")
+ .Doc(R"doc(
+Computes the "logical or" of elements across dimensions of a tensor.
+
+Reduces `input` along the dimensions given in `reduction_indices`. Unless
+`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+`reduction_indices`. If `keep_dims` is true, the reduced dimensions are
+retained with length 1.
+
+input: The tensor to reduce.
+reduction_indices: The dimensions to reduce.
+keep_dims: If true, retain reduced dimensions with length 1.
+output: The reduced tensor.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Range")
+ .Input("start: int32")
+ .Input("limit: int32")
+ .Input("delta: int32")
+ .Output("output: int32")
+ .Doc(R"doc(
+Creates a sequence of integers.
+
+This operation creates a sequence of integers that begins at `start` and
+extends by increments of `delta` up to but not including `limit`.
+
+For example:
+
+```
+# 'start' is 3
+# 'limit' is 18
+# 'delta' is 3
+tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+```
+
+start: 0-D (scalar). First entry in the sequence.
+limit: 0-D (scalar). Upper limit of sequence, exclusive.
+delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
+output: 1-D.
+)doc");
+
+REGISTER_OP("LinSpace")
+ .Input("start: T")
+ .Input("stop: T")
+ .Input("num: int32")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Generates values in an interval.
+
+A sequence of `num` evenly-spaced values is generated beginning at `start`.
+If `num > 1`, the values in the sequence increase by
+`(stop - start) / (num - 1)`, so that the last one is exactly `stop`.
+
+For example:
+
+```
+tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
+```
+
+start: First entry in the range.
+stop: Last entry in the range.
+num: Number of values to generate.
+output: 1-D. The generated values.
+)doc");
+
+REGISTER_OP("Complex")
+ .Input("real: float")
+ .Input("imag: float")
+ .Output("out: complex64")
+ .Doc(R"doc(
+Converts two real numbers to a complex number.
+
+Given a tensor `real` representing the real part of a complex number, and a
+tensor `imag` representing the imaginary part of a complex number, this
+operation returns complex numbers elementwise of the form \\(a + bj\\), where
+*a* represents the `real` part and *b* represents the `imag` part.
+
+The input tensors `real` and `imag` must have the same shape.
+
+For example:
+
+```
+# tensor 'real' is [2.25, 3.25]
+# tensor `imag` is [4.75, 5.75]
+tf.complex(real, imag) ==> [2.25 + 4.75j, 3.25 + 5.75j]
+```
+)doc");
+
+REGISTER_OP("Real")
+ .Input("in: complex64")
+ .Output("out: float")
+ .Doc(R"doc(
+Returns the real part of a complex number.
+
+Given a tensor `in` of complex numbers, this operation returns a tensor of type
+`float` that is the real part of each element in `in`. All elements in `in`
+must be complex numbers of the form \\(a + bj\\), where *a* is the real part
+returned by this operation and *b* is the imaginary part.
+
+For example:
+
+```
+# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.real(in) ==> [-2.25, 3.25]
+```
+)doc");
+
+REGISTER_OP("Imag")
+ .Input("in: complex64")
+ .Output("out: float")
+ .Doc(R"doc(
+Returns the imaginary part of a complex number.
+
+Given a tensor `in` of complex numbers, this operation returns a tensor of type
+`float` that is the imaginary part of each element in `in`. All elements in `in`
+must be complex numbers of the form \\(a + bj\\), where *a* is the real part
+and *b* is the imaginary part returned by this operation.
+
+For example:
+
+```
+# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.imag(in) ==> [4.75, 5.75]
+```
+)doc");
+
+REGISTER_OP("Conj")
+ .Input("in: complex64")
+ .Output("out: complex64")
+ .Doc(R"doc(
+Returns the complex conjugate of a complex number.
+
+Given a tensor `in` of complex numbers, this operation returns a tensor of
+complex numbers that are the complex conjugate of each element in `in`. The
+complex numbers in `in` must be of the form \\(a + bj\\), where *a* is the real
+part and *b* is the imaginary part.
+
+The complex conjugate returned by this operation is of the form \\(a - bj\\).
+
+For example:
+
+```
+# tensor 'in' is [-2.25 + 4.75j, 3.25 + 5.75j]
+tf.conj(in) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
+```
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
new file mode 100644
index 0000000000..03ba49d5cd
--- /dev/null
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -0,0 +1,543 @@
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/util/padding.h"
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("AvgPool")
+ .Input("value: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Performs average pooling on the input.
+
+Each entry in `output` is the mean of the corresponding size `ksize`
+window in `value`.
+
+value: 4-D with shape `[batch, height, width, channels]`.
+ksize: The size of the sliding window for each dimension of `value`.
+strides: The stride of the sliding window for each dimension of `value`.
+padding: The type of padding algorithm to use.
+output: The average pooled output tensor.
+)doc");
+
+REGISTER_OP("AvgPoolGrad")
+ .Input("orig_input_shape: int32")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes gradients of the average pooling function.
+
+orig_input_shape: 1-D. Shape of the original input to `avg_pool`.
+grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t.
+ the output of `avg_pool`.
+ksize: The size of the sliding window for each dimension of the input.
+strides: The stride of the sliding window for each dimension of the input.
+padding: The type of padding algorithm to use.
+output: 4-D. Gradients w.r.t. the input of `avg_pool`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BatchNormWithGlobalNormalization")
+ .Input("t: T")
+ .Input("m: T")
+ .Input("v: T")
+ .Input("beta: T")
+ .Input("gamma: T")
+ .Output("result: T")
+ .Attr("T: numbertype")
+ .Attr("variance_epsilon: float")
+ .Attr("scale_after_normalization: bool")
+ .Doc(R"doc(
+Batch normalization.
+
+t: A 4D input Tensor.
+m: A 1D mean Tensor with size matching the last dimension of t.
+ This is the first output from MovingMoments.
+v: A 1D variance Tensor with size matching the last dimension of t.
+ This is the second output from MovingMoments.
+beta: A 1D beta Tensor with size matching the last dimension of t.
+ An offset to be added to the normalized tensor.
+gamma: A 1D gamma Tensor with size matching the last dimension of t.
+ If "scale_after_normalization" is true, this tensor will be multiplied
+ with the normalized tensor.
+variance_epsilon: A small float number to avoid dividing by 0.
+scale_after_normalization: A bool indicating whether the resulted tensor
+ needs to be multiplied with gamma.
+)doc");
+
+REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
+ .Input("t: T")
+ .Input("m: T")
+ .Input("v: T")
+ .Input("gamma: T")
+ .Input("backprop: T")
+ .Output("dx: T")
+ .Output("dm: T")
+ .Output("dv: T")
+ .Output("db: T")
+ .Output("dg: T")
+ .Attr("T: numbertype")
+ .Attr("variance_epsilon: float")
+ .Attr("scale_after_normalization: bool")
+ .Doc(R"doc(
+Gradients for batch normalization.
+
+t: A 4D input Tensor.
+m: A 1D mean Tensor with size matching the last dimension of t.
+ This is the first output from MovingMoments.
+v: A 1D variance Tensor with size matching the last dimension of t.
+ This is the second output from MovingMoments.
+gamma: A 1D gamma Tensor with size matching the last dimension of t.
+ If "scale_after_normalization" is true, this Tensor will be multiplied
+ with the normalized Tensor.
+backprop: 4D backprop Tensor.
+variance_epsilon: A small float number to avoid dividing by 0.
+scale_after_normalization: A bool indicating whether the resulted tensor
+ needs to be multiplied with gamma.
+
+dx: 4D backprop tensor for input.
+dm: 1D backprop tensor for mean.
+dv: 1D backprop tensor for variance.
+db: 1D backprop tensor for beta.
+dg: 1D backprop tensor for gamma.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("BiasAdd")
+ .Attr("T: numbertype")
+ .Input("value: T")
+ .Input("bias: T")
+ .Output("output: T")
+ .Doc(R"doc(
+Adds `bias` to `value`.
+
+This is a special case of `tf.add` where `bias` is restricted to be 1-D.
+Broadcasting is supported, so `value` may have any number of dimensions.
+
+value: Any number of dimensions.
+bias: 1-D with size the last dimension of `value`.
+output: Broadcasted sum of `value` and `bias`.
+)doc");
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Conv2D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes a 2-D convolution given 4-D `input` and `filter` tensors.
+
+Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
+and a filter / kernel tensor of shape
+`[filter_height, filter_width, in_channels, out_channels]`, this op
+performs the following:
+
+1. Flattens the filter to a 2-D matrix with shape
+ `[filter_height * filter_width * in_channels, output_channels]`.
+2. Extracts image patches from the input tensor to form a *virtual*
+ tensor of shape `[batch, out_height, out_width,
+ filter_height * filter_width * in_channels]`.
+3. For each patch, right-multiplies the filter matrix and the image patch
+ vector.
+
+In detail,
+
+ output[b, i, j, k] =
+ sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
+ filter[di, dj, q, k]
+
+Must have `strides[0] = strides[3] = 1`. For the most common case of the same
+horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
+
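+As an illustrative shape example (shapes assumed for illustration only):
+
+```
+# 'input' has shape [1, 5, 5, 1] and 'filter' has shape [3, 3, 1, 8].
+# With strides = [1, 1, 1, 1]:
+#   padding = "SAME"  gives an output of shape [1, 5, 5, 8]
+#   padding = "VALID" gives an output of shape [1, 3, 3, 8]
+```
+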
+strides: 1-D of length 4. The stride of the sliding window for each dimension
+ of `input`.
+padding: The type of padding algorithm to use.
+)doc");
+
+REGISTER_OP("Conv2DBackpropInput")
+ .Input("input_sizes: int32")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes the gradients of convolution with respect to the input.
+
+input_sizes: An integer vector representing the shape of `input`,
+ where `input` is a 4-D `[batch, height, width, channels]` tensor.
+filter: 4-D with shape
+ `[filter_height, filter_width, in_channels, out_channels]`.
+out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`.
+ Gradients w.r.t. the output of the convolution.
+strides: The stride of the sliding window for each dimension of the input
+ of the convolution.
+padding: The type of padding algorithm to use.
+output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient
+ w.r.t. the input of the convolution.
+)doc");
+
+// TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
+// more general string attribute ('kernel_impl'?) that can be used to
+// select among several possible implementations.
+REGISTER_OP("Conv2DBackpropFilter")
+ .Input("input: T")
+ .Input("filter_sizes: int32")
+ .Output("output: T")
+ .Input("out_backprop: T")
+ .Attr("T: {float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Computes the gradients of convolution with respect to the filter.
+
+input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+filter_sizes: An integer vector representing the tensor shape of `filter`,
+ where `filter` is a 4-D
+ `[filter_height, filter_width, in_channels, out_channels]` tensor.
+out_backprop: 4-D with shape `[batch, out_height, out_width, out_channels]`.
+ Gradients w.r.t. the output of the convolution.
+strides: The stride of the sliding window for each dimension of the input
+ of the convolution.
+padding: The type of padding algorithm to use.
+output: 4-D with shape
+ `[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
+ the `filter` input of the convolution.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("L2Loss")
+ .Input("t: T")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Doc(R"doc(
+L2 Loss.
+
+Computes half the L2 norm of a tensor without the `sqrt`:
+
+ output = sum(t ** 2) / 2
+
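+For example, an illustrative case (assuming a Python wrapper `tf.nn.l2_loss`):
+
+```
+# 't' is [3.0, 4.0]
+tf.nn.l2_loss(t) ==> 12.5
+```
+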
+t: Typically 2-D, but may have any dimensions.
+output: 0-D.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("LRN")
+ .Input("input: float")
+ .Output("output: float")
+ .Attr("depth_radius: int = 5")
+ .Attr("bias: float = 1.0")
+ .Attr("alpha: float = 1.0")
+ .Attr("beta: float = 0.5")
+ .Doc(R"doc(
+Local Response Normalization.
+
+The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last
+dimension), and each vector is normalized independently. Within a given vector,
+each component is divided by the weighted, squared sum of inputs within
+`depth_radius`. In detail,
+
+ sqr_sum[a, b, c, d] =
+ sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
+ output = input / (bias + alpha * sqr_sum ** beta)
+
+For details, see [Krizhevsky et al., ImageNet classification with deep
+convolutional neural networks (NIPS 2012)]
+(http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
+
+input: 4-D.
+depth_radius: 0-D. Half-width of the 1-D normalization window.
+bias: An offset (usually positive to avoid dividing by 0).
+alpha: A scale factor, usually positive.
+beta: An exponent.
+)doc");
+
+REGISTER_OP("LRNGrad")
+ .Input("input_grads: float")
+ .Input("input_image: float")
+ .Input("output_image: float")
+ .Output("output: float")
+ .Attr("depth_radius: int = 5")
+ .Attr("bias: float = 1.0")
+ .Attr("alpha: float = 1.0")
+ .Attr("beta: float = 0.5")
+ .Doc(R"doc(
+Gradients for Local Response Normalization.
+
+input_grads: 4-D with shape `[batch, height, width, channels]`.
+input_image: 4-D with shape `[batch, height, width, channels]`.
+output_image: 4-D with shape `[batch, height, width, channels]`.
+depth_radius: A depth radius.
+bias: An offset (usually > 0 to avoid dividing by 0).
+alpha: A scale factor, usually positive.
+beta: An exponent.
+output: The gradients for LRN.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("MaxPool")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Input("input: float")
+ .Output("output: float")
+ .Doc(R"doc(
+Performs max pooling on the input.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: 4-D input to pool over.
+output: The max pooled output tensor.
+)doc");
+
+REGISTER_OP("MaxPoolGrad")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Input("orig_input: float")
+ .Input("orig_output: float")
+ .Input("grad: float")
+ .Output("output: float")
+ .Doc(R"doc(
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients w.r.t. the output of `max_pool`.
+output: Gradients w.r.t. the input to `max_pool`.
+)doc");
+
+REGISTER_OP("MaxPoolWithArgmax")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr("Targmax: {int32, int64} = DT_INT64")
+ .Attr(GetPaddingAttrString())
+ .Input("input: float")
+ .Output("output: float")
+ .Output("argmax: Targmax")
+ .Doc(R"doc(
+Performs max pooling on the input and outputs both max values and indices.
+
+The indices in `argmax` are flattened, so that a maximum value at position
+`[b, y, x, c]` becomes flattened index
+`((b * height + y) * width + x) * channels + c`.
+
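+For example, an illustrative index calculation (shapes assumed for
+illustration only):
+
+```
+# With an input of shape [batch, height, width, channels] = [1, 4, 4, 3],
+# a maximum at position [b, y, x, c] = [0, 1, 2, 1] has flattened index
+# ((0 * 4 + 1) * 4 + 2) * 3 + 1 = 19.
+```
+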
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: 4-D with shape `[batch, height, width, channels]`. Input to pool over.
+output: The max pooled output tensor.
+argmax: 4-D. The flattened indices of the max values chosen for each output.
+)doc");
+
+REGISTER_OP("MaxPoolGradWithArgmax")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("Targmax: {int32, int64}")
+ .Input("input: float")
+ .Input("grad: float")
+ .Input("argmax: Targmax")
+ .Output("output: float")
+ .Doc(R"doc(
+Computes gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: The original input.
+grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the
+ output of `max_pool`.
+argmax: The indices of the maximum values chosen for each output of `max_pool`.
+output: Gradients w.r.t. the input of `max_pool`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Relu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear: `max(features, 0)`.
+)doc");
+
+REGISTER_OP("ReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear gradients for a Relu operation.
+
+gradients: The backpropagated gradients to the corresponding Relu operation.
+features: The features passed as input to the corresponding Relu operation.
+backprops: The gradients: `gradients * (features > 0)`.
+)doc");
+
+REGISTER_OP("Relu6")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear 6: `min(max(features, 0), 6)`.
+)doc");
+
+REGISTER_OP("Relu6Grad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes rectified linear 6 gradients for a Relu6 operation.
+
+gradients: The backpropagated gradients to the corresponding Relu6 operation.
+features: The features passed as input to the corresponding Relu6 operation.
+backprops: The gradients:
+  `gradients * (features > 0) * (features < 6)`.
+)doc");
+
+REGISTER_OP("Softplus")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes softplus: `log(exp(features) + 1)`.
+)doc");
+
+REGISTER_OP("SoftplusGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Computes softplus gradients for a softplus operation.
+
+gradients: The backpropagated gradients to the corresponding softplus operation.
+features: The features passed as input to the corresponding softplus operation.
+backprops: The gradients: `gradients / (1 + exp(-features))`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("Softmax")
+ .Input("logits: T")
+ .Output("softmax: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes softmax activations.
+
+For each batch `i` and class `j` we have
+
+ softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
+
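+For example, an illustrative case (assuming a Python wrapper `tf.nn.softmax`):
+
+```
+# 'logits' is [[0.0, 1.0]]
+tf.nn.softmax(logits) ==> [[0.26894142, 0.73105858]]
+```
+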
+logits: 2-D with shape `[batch_size, num_classes]`.
+softmax: Same shape as `logits`.
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("SoftmaxCrossEntropyWithLogits")
+ .Input("features: T")
+ .Input("labels: T")
+ .Output("loss: T")
+ .Output("backprop: T")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Computes softmax cross entropy cost and gradients to backpropagate.
+
+Inputs are the logits, not probabilities.
+
+features: batch_size x num_classes matrix
+labels: batch_size x num_classes matrix
+ The caller must ensure that each batch of labels represents a valid
+ probability distribution.
+loss: Per example loss (batch_size vector).
+backprop: backpropagated gradients (batch_size x num_classes matrix).
+)doc");
+
+// --------------------------------------------------------------------------
+
+REGISTER_OP("InTopK")
+ .Attr("k: int")
+ .Input("predictions: float")
+ .Input("targets: int32")
+ .Output("precision: bool")
+ .Doc(R"doc(
+Says whether the targets are in the top K predictions.
+
+This outputs a batch_size bool array; an entry out[i] is true if the
+prediction for the target class is among the top k predictions among
+all predictions for example i. Note that the behavior of InTopK differs
+from the TopK op in its handling of ties; if multiple classes have the
+same prediction value and straddle the top-k boundary, all of those
+classes are considered to be in the top k.
+
+More formally, let
+
+ \\(predictions_i\\) be the predictions for all classes for example i,
+ \\(targets_i\\) be the target class for example i,
+ \\(out_i\\) be the output for example i,
+
+$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
+
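+For example, an illustrative case (values assumed for illustration only):
+
+```
+# 'predictions' is [[0.1, 0.3, 0.2, 0.4]], 'targets' is [1], and k = 2.
+# The top-2 classes for the single example are {3, 1}, so target class 1 is
+# in the top 2 and the output is [true].
+```
+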
+predictions: A batch_size x classes tensor
+targets: A batch_size vector of class ids
+k: Number of top elements to look at for computing precision
+precision: Computed Precision at k as a bool Tensor
+
+)doc");
+
+REGISTER_OP("TopK")
+ .Attr("k: int >= 1")
+ .Input("input: T")
+ .Output("values: T")
+ .Output("indices: int32")
+ .Attr("T: realnumbertype")
+ .Doc(R"doc(
+Returns the values and indices of the k largest elements for each row.
+
+\\(values_{i, j}\\) represents the j-th largest element in \\(input_i\\).
+
+\\(indices_{i, j}\\) gives the column index of the corresponding element,
+such that \\(input_{i, indices_{i, j}} = values_{i, j}\\). If two
+elements are equal, the lower-index element appears first.
+
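+For example, an illustrative case with `k = 2` (values assumed for
+illustration only):
+
+```
+# 'input' is [[1, 3, 5],
+#             [6, 4, 2]]
+values  ==> [[5, 3],
+             [6, 4]]
+indices ==> [[2, 1],
+             [0, 1]]
+```
+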
+k: Number of top elements to look for within each row
+input: A batch_size x classes tensor
+values: A batch_size x k tensor with the k largest elements for each row,
+ sorted in descending order
+indices: A batch_size x k tensor with the index of each value within each row
+
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/no_op.cc b/tensorflow/core/ops/no_op.cc
new file mode 100644
index 0000000000..52778917cb
--- /dev/null
+++ b/tensorflow/core/ops/no_op.cc
@@ -0,0 +1,10 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("NoOp")
+ .Doc(R"doc(
+Does nothing. Only useful as a placeholder for control edges.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
new file mode 100644
index 0000000000..7fcaa3abf1
--- /dev/null
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -0,0 +1,104 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("DecodeRaw")
+ .Input("bytes: string")
+ .Output("output: out_type")
+ .Attr("out_type: {float,double,int32,uint8,int16,int8,int64}")
+ .Attr("little_endian: bool = true")
+ .Doc(R"doc(
+Reinterpret the bytes of a string as a vector of numbers.
+
+bytes: All the elements must have the same length.
+little_endian: Whether the input bytes are in little-endian order.
+ Ignored for out_types that are stored in a single byte like uint8.
+output: A Tensor with one more dimension than the input bytes. The
+ added dimension will have size equal to the length of the elements
+ of bytes divided by the number of bytes to represent out_type.
+)doc");
+
+REGISTER_OP("ParseExample")
+ .Input("serialized: string")
+ .Input("names: string")
+ .Input("sparse_keys: Nsparse * string")
+ .Input("dense_keys: Ndense * string")
+ .Input("dense_defaults: Tdense")
+ .Output("sparse_indices: Nsparse * int64")
+ .Output("sparse_values: sparse_types")
+ .Output("sparse_shapes: Nsparse * int64")
+ .Output("dense_values: Tdense")
+ .Attr("Nsparse: int >= 0") // Inferred from sparse_keys
+ .Attr("Ndense: int >= 0") // Inferred from dense_keys
+ .Attr("sparse_types: list({float,int64,string}) >= 0")
+ .Attr("Tdense: list({float,int64,string}) >= 0")
+ .Attr("dense_shapes: list(shape) >= 0")
+ .Doc(R"doc(
+Transforms a vector of brain.Example protos (as strings) into typed tensors.
+
+serialized: A vector containing a batch of binary serialized Example protos.
+names: A vector containing the names of the serialized protos.
+ May contain, for example, table key (descriptive) names for the
+ corresponding serialized protos. These are purely useful for debugging
+ purposes, and the presence of values here has no effect on the output.
+ May also be an empty vector if no names are available.
+ If non-empty, this vector must be the same length as "serialized".
+dense_keys: A list of Ndense string Tensors (scalars).
+ The keys expected in the Examples' features associated with dense values.
+dense_defaults: A list of Ndense Tensors (some may be empty).
+ dense_defaults[j] provides default values
+ when the example's feature_map lacks dense_key[j]. If an empty Tensor is
+ provided for dense_defaults[j], then the Feature dense_keys[j] is required.
+ The input type is inferred from dense_defaults[j], even when it's empty.
+ If dense_defaults[j] is not empty, its shape must match dense_shapes[j].
+dense_shapes: A list of Ndense shapes; the shapes of data in each Feature
+ given in dense_keys.
+ The number of elements in the Feature corresponding to dense_key[j]
+ must always equal dense_shapes[j].NumEntries().
+ If dense_shapes[j] == (D0, D1, ..., DN) then the shape of the output
+ Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN):
+ The dense outputs are just the inputs row-stacked by batch.
+sparse_keys: A list of Nsparse string Tensors (scalars).
+ The keys expected in the Examples' features associated with sparse values.
+sparse_types: A list of Nsparse types; the data types of data in each Feature
+ given in sparse_keys.
+ Currently the ParseExample supports DT_FLOAT (FloatList),
+ DT_INT64 (Int64List), and DT_STRING (BytesList).
+)doc");
+
+REGISTER_OP("DecodeCSV")
+ .Input("records: string")
+ .Input("record_defaults: OUT_TYPE")
+ .Output("output: OUT_TYPE")
+ .Attr("OUT_TYPE: list({float,int32,int64,string})")
+ .Attr("field_delim: string = ','")
+ .Doc(R"doc(
+Convert CSV records to tensors. Each column maps to one tensor.
+
+RFC 4180 format is expected for the CSV records.
+(https://tools.ietf.org/html/rfc4180)
+Note that we allow leading and trailing spaces with int or float fields.
+
+records: Each string is a record/row in the csv and all records should have
+ the same format.
+record_defaults: One tensor per column of the input record, with either a
+ scalar default value for that column or empty if the column is required.
+field_delim: delimiter to separate fields in a record.
+output: Each tensor will have the same shape as records.
+)doc");
+
+REGISTER_OP("StringToNumber")
+ .Input("string_tensor: string")
+ .Output("output: out_type")
+ .Attr("out_type: {float, int32} = DT_FLOAT")
+ .Doc(R"doc(
+Converts each string in the input Tensor to the specified numeric type.
+
+(Note that int32 overflow results in an error while float overflow
+results in a rounded value.)
+
+out_type: The numeric type to interpret each string in string_tensor as.
+output: A Tensor of the same shape as the input string_tensor.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
new file mode 100644
index 0000000000..4be4354b85
--- /dev/null
+++ b/tensorflow/core/ops/random_ops.cc
@@ -0,0 +1,108 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("RandomUniform")
+ .Input("shape: T")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a uniform distribution.
+
+The generated values follow a uniform distribution in the range `[0, 1)`. The
+lower bound 0 is included in the range, while the upper bound 1 is excluded.
+
+shape: The shape of the output tensor.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of the specified shape filled with uniform random values.
+)doc");
+
+REGISTER_OP("RandomStandardNormal")
+ .Input("shape: T")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a normal distribution.
+
+The generated values will have mean 0 and standard deviation 1.
+
+shape: The shape of the output tensor.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of the specified shape filled with random normal values.
+)doc");
+
+REGISTER_OP("TruncatedNormal")
+ .Input("shape: T")
+ .SetIsStateful()
+ .Output("output: dtype")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("dtype: {float,double}")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Outputs random values from a truncated normal distribution.
+
+The generated values follow a normal distribution with mean 0 and standard
+deviation 1, except that values whose magnitude is more than 2 standard
+deviations from the mean are dropped and re-picked.
+
+shape: The shape of the output tensor.
+dtype: The type of the output.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of the specified shape filled with random truncated normal
+ values.
+)doc");
+
+REGISTER_OP("RandomShuffle")
+ .Input("value: T")
+ .SetIsStateful()
+ .Output("output: T")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("T: type")
+ .Doc(R"doc(
+Randomly shuffles a tensor along its first dimension.
+
+The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
+to one and only one `output[i]`. For example, a mapping that might occur for a
+3x2 tensor is:
+
+```prettyprint
+[[1, 2], [[5, 6],
+ [3, 4], ==> [1, 2],
+ [5, 6]] [3, 4]]
+```
+
+value: The tensor to be shuffled.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+ generator is seeded by the given seed. Otherwise, it is seeded by a
+ random seed.
+seed2: A second seed to avoid seed collision.
+
+output: A tensor of same shape and type as `value`, shuffled along its first
+ dimension.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/sendrecv_ops.cc b/tensorflow/core/ops/sendrecv_ops.cc
new file mode 100644
index 0000000000..51158263c1
--- /dev/null
+++ b/tensorflow/core/ops/sendrecv_ops.cc
@@ -0,0 +1,99 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_Send")
+ .Input("tensor: T")
+ .Attr("T: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Sends the named tensor from send_device to recv_device.
+
+tensor: The tensor to send.
+tensor_name: The name of the tensor to send.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+REGISTER_OP("_Recv")
+ .Output("tensor: tensor_type")
+ .Attr("tensor_type: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Receives the named tensor from send_device on recv_device.
+
+tensor: The tensor to receive.
+tensor_name: The name of the tensor to receive.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+REGISTER_OP("_HostSend")
+ .Input("tensor: T")
+ .Attr("T: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Sends the named tensor from send_device to recv_device.
+
+_HostSend requires its input on host memory whereas _Send requires its
+input on device memory.
+
+tensor: The tensor to send.
+tensor_name: The name of the tensor to send.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+REGISTER_OP("_HostRecv")
+ .Output("tensor: tensor_type")
+ .Attr("tensor_type: type")
+ .Attr("tensor_name: string")
+ .Attr("send_device: string")
+ .Attr("send_device_incarnation: int")
+ .Attr("recv_device: string")
+ .Attr("client_terminated: bool = false")
+ .Doc(R"doc(
+Receives the named tensor from send_device on recv_device.
+
+_HostRecv requires its input on host memory whereas _Recv requires its
+input on device memory.
+
+tensor: The tensor to receive.
+tensor_name: The name of the tensor to receive.
+send_device: The name of the device sending the tensor.
+send_device_incarnation: The current incarnation of send_device.
+recv_device: The name of the device receiving the tensor.
+client_terminated: If set to true, this indicates that the node was added
+ to the graph as a result of a client-side feed or fetch of Tensor data,
+ in which case the corresponding send or recv is expected to be managed
+ locally by the caller.
+)doc");
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
new file mode 100644
index 0000000000..51262373d5
--- /dev/null
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -0,0 +1,134 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("SparseToDense")
+ .Input("sparse_indices: Tindices")
+ .Input("output_shape: Tindices")
+ .Input("sparse_values: T")
+ .Input("default_value: T")
+ .Output("dense: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Doc(R"doc(
+Converts a sparse representation into a dense tensor.
+
+Builds an array `dense` with shape `output_shape` such that
+
+```prettyprint
+# If sparse_indices is scalar
+dense[i] = (i == sparse_indices ? sparse_values : default_value)
+
+# If sparse_indices is a vector, then for each i
+dense[sparse_indices[i]] = sparse_values[i]
+
+# If sparse_indices is an n by d matrix, then for each i in [0, n)
+dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
+```
+
+All other values in `dense` are set to `default_value`. If `sparse_values` is a
+scalar, all sparse indices are set to this single value.
+
+sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete
+ index where `sparse_values[i]` will be placed.
+output_shape: 1-D. Shape of the dense output tensor.
+sparse_values: 1-D. Values corresponding to each row of `sparse_indices`,
+ or a scalar value to be used for all sparse indices.
+default_value: Scalar value to set for indices not specified in
+ `sparse_indices`.
+dense: Dense output tensor of shape `output_shape`.
+)doc");
+
+REGISTER_OP("SparseConcat")
+ .Input("indices: N * int64")
+ .Input("values: N * T")
+ .Input("shapes: N * int64")
+ .Output("output_indices: int64")
+ .Output("output_values: T")
+ .Output("output_shape: int64")
+ .Attr("concat_dim: int >= 0")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Doc(R"doc(
+Concatenates a list of `SparseTensor` along the specified dimension.
+
+Concatenation is with respect to the dense versions of these sparse tensors.
+It is assumed that each input is a `SparseTensor` whose elements are ordered
+along increasing dimension number.
+
+All inputs' shapes must match, except for the concat dimension. The
+`indices`, `values`, and `shapes` lists must have the same length.
+
+The output shape is identical to the inputs', except along the concat
+dimension, where it is the sum of the inputs' sizes along that dimension.
+
+The output elements will be resorted to preserve the sort order along
+increasing dimension number.
+
+This op runs in `O(M log M)` time, where `M` is the total number of non-empty
+values across all inputs. This is due to the need for an internal sort in
+order to concatenate efficiently across an arbitrary dimension.
+
+For example, if `concat_dim = 1` and the inputs are
+
+ sp_inputs[0]: shape = [2, 3]
+ [0, 2]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+ sp_inputs[1]: shape = [2, 4]
+ [0, 1]: "d"
+ [0, 2]: "e"
+
+then the output will be
+
+ shape = [2, 7]
+ [0, 2]: "a"
+ [0, 4]: "d"
+ [0, 5]: "e"
+ [1, 0]: "b"
+ [1, 1]: "c"
+
+Graphically this is equivalent to doing
+
+ [ a] concat [ d e ] = [ a d e ]
+ [b c ] [ ] [b c ]
+
+indices: 2-D. Indices of each input `SparseTensor`.
+values: 1-D. Non-empty values of each `SparseTensor`.
+shapes: 1-D. Shapes of each `SparseTensor`.
+output_indices: 2-D. Indices of the concatenated `SparseTensor`.
+output_values: 1-D. Non-empty values of the concatenated `SparseTensor`.
+output_shape: 1-D. Shape of the concatenated `SparseTensor`.
+concat_dim: Dimension to concatenate along.
+)doc");
+
+REGISTER_OP("SparseReorder")
+ .Input("input_indices: int64")
+ .Input("input_values: T")
+ .Input("input_shape: int64")
+ .Output("output_indices: int64")
+ .Output("output_values: T")
+ .Attr("T: type")
+ .Doc(R"doc(
+Reorders a SparseTensor into the canonical, row-major ordering.
+
+Note that by convention, all sparse ops preserve the canonical ordering along
+increasing dimension number. The only time ordering can be violated is during
+manual manipulation of the indices and values vectors to add entries.
+
+Reordering does not affect the shape of the SparseTensor.
+
+If the tensor has rank `R` and `N` non-empty values, `input_indices` has
+shape `[N, R]`, input_values has length `N`, and input_shape has length `R`.
+
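+For example, an illustrative case (values assumed for illustration only):
+
+```prettyprint
+# 'input_indices' is [[0, 3], [0, 1]]
+# 'input_values' is [10, 20]
+# 'input_shape' is [1, 4]
+output_indices ==> [[0, 1], [0, 3]]
+output_values  ==> [20, 10]
+```
+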
+input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
+ SparseTensor, possibly not in canonical ordering.
+input_values: 1-D. `N` non-empty values corresponding to `input_indices`.
+input_shape: 1-D. Shape of the input SparseTensor.
+output_indices: 2-D. `N x R` matrix with the same indices as input_indices, but
+ in canonical row-major ordering.
+output_values: 1-D. `N` non-empty values corresponding to `output_indices`.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
new file mode 100644
index 0000000000..da9fd4ad08
--- /dev/null
+++ b/tensorflow/core/ops/state_ops.cc
@@ -0,0 +1,290 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("Variable")
+ .Output("ref: Ref(dtype)")
+ .Attr("shape: shape")
+ .Attr("dtype: type")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+Holds state in the form of a tensor that persists across steps.
+
+Outputs a ref to the tensor state so it may be read or modified.
+TODO(zhifengc/mrry): Add a pointer to a more detailed document
+about sharing states in tensorflow.
+
+ref: A reference to the variable tensor.
+shape: The shape of the variable tensor.
+dtype: The type of elements in the variable tensor.
+container: If non-empty, this variable is placed in the given container.
+ Otherwise, a default container is used.
+shared_name: If non-empty, this variable is named in the given bucket
+ with this shared_name. Otherwise, the node name is used instead.
+)doc");
+
+REGISTER_OP("TemporaryVariable")
+ .Output("ref: Ref(dtype)")
+ .Attr("shape: shape")
+ .Attr("dtype: type")
+ .Attr("var_name: string = ''")
+ .SetIsStateful()
+ .Doc(R"doc(
+Returns a tensor that may be mutated, but only persists within a single step.
+
+This is an experimental op for internal use only and it is possible to use this
+op in unsafe ways. DO NOT USE unless you fully understand the risks.
+
+It is the caller's responsibility to ensure that 'ref' is eventually passed to a
+matching 'DestroyTemporaryVariable' op after all other uses have completed.
+
+Outputs a ref to the tensor state so it may be read or modified.
+
+ E.g.
+ var = state_ops._temporary_variable([1, 2], types.float_)
+ var_name = var.op.name
+ var = state_ops.assign(var, [[4.0, 5.0]])
+ var = state_ops.assign_add(var, [[6.0, 7.0]])
+ final = state_ops._destroy_temporary_variable(var, var_name=var_name)
+
+ref: A reference to the variable tensor.
+shape: The shape of the variable tensor.
+dtype: The type of elements in the variable tensor.
+var_name: Overrides the name used for the temporary variable resource. Default
+value is the name of the 'TemporaryVariable' op (which is guaranteed unique).
+)doc");
+
+REGISTER_OP("DestroyTemporaryVariable")
+ .Input("ref: Ref(T)")
+ .Output("value: T")
+ .Attr("T: type")
+ .Attr("var_name: string")
+ .Doc(R"doc(
+Destroys the temporary variable and returns its final value.
+
+Sets output to the value of the Tensor pointed to by 'ref', then destroys
+the temporary variable called 'var_name'.
+All other uses of 'ref' *must* have executed before this op.
+This is typically achieved by chaining the ref through each assign op, or by
+using control dependencies.
+
+Outputs the final value of the tensor pointed to by 'ref'.
+
+ref: A reference to the temporary variable tensor.
+var_name: Name of the temporary variable, usually the name of the matching
+'TemporaryVariable' op.
+)doc");
+
+REGISTER_OP("Assign")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("validate_shape: bool = true")
+ .Attr("use_locking: bool = true")
+ .SetAllowsUninitializedInput()
+ .Doc(R"doc(
+Update 'ref' by assigning 'value' to it.
+
+This operation outputs "ref" after the assignment is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node. May be uninitialized.
+value: The value to be assigned to the variable.
+validate_shape: If true, the operation will validate that the shape
+ of 'value' matches the shape of the Tensor being assigned to. If false,
+ 'ref' will take on the shape of 'value'.
+use_locking: If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been reset.
+)doc");
+
+REGISTER_OP("AssignAdd")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update 'ref' by adding 'value' to it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be added to the variable.
+use_locking: If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("AssignSub")
+ .Input("ref: Ref(T)")
+ .Input("value: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update 'ref' by subtracting 'value' from it.
+
+This operation outputs "ref" after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+ref: Should be from a `Variable` node.
+value: The value to be subtracted from the variable.
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+output_ref:= Same as "ref". Returned as a convenience for operations that want
+ to use the new value after the variable has been updated.
+)doc");
+
+REGISTER_OP("ScatterUpdate")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .Doc(R"doc(
+Applies sparse updates to a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] = updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] = updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+If `indices` contains duplicate entries, lexicographically later entries
+override earlier entries.
+
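+For example (values chosen for illustration), with `ref = [10, 20, 30, 40]`,
+`indices = [1, 3, 1]` and `updates = [11, 13, 99]`, the result is
+`[10, 99, 30, 13]`: the later update to index 1 wins.
+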
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterUpdate.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to store in `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the assignment will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterAdd")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Adds sparse updates to a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] += updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] += updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions add.
+
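+For example (values chosen for illustration), with `ref = [1, 2, 3]`,
+`indices = [0, 0]` and `updates = [5, 7]`, the result is `[13, 2, 3]`, since
+both contributions to index 0 accumulate.
+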
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterAdd.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to add to `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the addition will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterSub")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Subtracts sparse updates from a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] -= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] -= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their (negated) contributions add.
+
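+For example (values chosen for illustration), with `ref = [10, 20]`,
+`indices = [1, 1]` and `updates = [3, 4]`, the result is `[10, 13]`.
+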
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/ScatterSub.png" alt>
+</div>
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to subtract from `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("CountUpTo")
+ .Input("ref: Ref(T)")
+ .Output("output: T")
+ .Attr("limit: int")
+ .Attr("T: {int32, int64}")
+ .Doc(R"doc(
+Increments 'ref' until it reaches 'limit'.
+
+This operation outputs "ref" after the update is done. This makes it
+easier to chain operations that need to use the updated value.
+
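+For example (illustrative), with `limit = 3` and a ref initialized to 0,
+successive executions output 0, 1, and 2; the next execution generates an
+'OutOfRange' error.
+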
+ref: Should be from a scalar `Variable` node.
+limit: If incrementing ref would bring it above limit, instead generates an
+ 'OutOfRange' error.
+output: A copy of the input before increment. If nothing else modifies the
+ input, the values produced will all be distinct.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
new file mode 100644
index 0000000000..57b471074c
--- /dev/null
+++ b/tensorflow/core/ops/string_ops.cc
@@ -0,0 +1,21 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("StringToHashBucket")
+ .Input("string_tensor: string")
+ .Output("output: int64")
+ .Attr("num_buckets: int >= 1")
+ .Doc(R"doc(
+Converts each string in the input Tensor to its hash modulo the number of buckets.
+
+The hash function is deterministic on the content of the string within the
+process.
+
+Note that the hash function may change from time to time.
+
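+For example (illustrative), with `num_buckets = 42` every output value lies in
+`[0, 42)`, and identical input strings always map to the same bucket within a
+single process.
+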
+num_buckets: The number of buckets.
+output: A Tensor of the same shape as the input string_tensor.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
new file mode 100644
index 0000000000..5f46c871b6
--- /dev/null
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -0,0 +1,115 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
+// inputs or outputs in various ways.
+
+REGISTER_OP("ScalarSummary")
+ .Input("tags: string")
+ .Input("values: T")
+ .Output("summary: string")
+ .Attr("T: {float, double}")
+ .Doc(R"doc(
+Outputs a `Summary` protocol buffer with scalar values.
+
+The input `tags` and `values` must have the same shape. The generated summary
+has a summary value for each tag-value pair in `tags` and `values`.
+
+tags: 1-D. Tags for the summary.
+values: 1-D, same size as `tags`. Values for the summary.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
+
+REGISTER_OP("HistogramSummary")
+ .Input("tag: string")
+ .Input("values: float")
+ .Output("summary: string")
+ .Doc(R"doc(
+Outputs a `Summary` protocol buffer with a histogram.
+
+The generated
+[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+has one summary value containing a histogram for `values`.
+
+This op reports an `OutOfRange` error if any value is not finite.
+
+tag: Scalar. Tag to use for the `Summary.Value`.
+values: Any shape. Values to use to build the histogram.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
+
+REGISTER_OP("ImageSummary")
+ .Input("tag: string")
+ .Input("tensor: float")
+ .Output("summary: string")
+ .Attr("max_images: int >= 1 = 3")
+ .Attr(
+ "bad_color: tensor = { dtype: DT_UINT8 "
+ "tensor_shape: { dim { size: 4 } } "
+ "int_val: 255 int_val: 0 int_val: 0 int_val: 255 }")
+ .Doc(R"doc(
+Outputs a `Summary` protocol buffer with images.
+
+The summary has up to `max_images` summary values containing images. The
+images are built from `tensor` which must be 4-D with shape `[batch_size,
+height, width, channels]` and where `channels` can be:
+
+* 1: `tensor` is interpreted as Grayscale.
+* 3: `tensor` is interpreted as RGB.
+* 4: `tensor` is interpreted as RGBA.
+
+The images have the same number of channels as the input tensor. Their values
+are normalized, one image at a time, to fit in the range `[0, 255]`. The
+op uses two different normalization algorithms:
+
+* If the input values are all positive, they are rescaled so the largest one
+ is 255.
+
+* If any input value is negative, the values are shifted so input value 0.0
+ is at 127. They are then rescaled so that either the smallest value is 0,
+ or the largest one is 255.
+
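+For example (values chosen for illustration), an all-positive input
+`[1.0, 2.0, 4.0]` is rescaled to approximately `[64, 128, 255]`, while an
+input containing a negative value, such as `[-1.0, 0.0, 2.0]`, is shifted and
+rescaled to approximately `[63, 127, 255]`.
+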
+The `tag` argument is a scalar `Tensor` of type `string`. It is used to
+build the `tag` of the summary values:
+
+* If `max_images` is 1, the summary value tag is '*tag*/image'.
+* If `max_images` is greater than 1, the summary value tags are
+ generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
+
+The `bad_color` argument is the color to use in the generated images for
+non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
+Each element must be in the range `[0, 255]` (it represents the value of a
+pixel in the output image). Non-finite values in the input tensor are
+replaced by this color in the output image. The default value is the color
+red.
+
+tag: Scalar. Used to build the `tag` attribute of the summary values.
+tensor: 4-D of shape `[batch_size, height, width, channels]` where
+ `channels` is 1, 3, or 4.
+max_images: Max number of batch elements to generate images for.
+bad_color: Color to use for pixels with non-finite values.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
+
+REGISTER_OP("MergeSummary")
+ .Input("inputs: N * string")
+ .Output("summary: string")
+ .Attr("N : int >= 1")
+ .Doc(R"doc(
+Merges summaries.
+
+This op creates a
+[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+protocol buffer that contains the union of all the values in the input
+summaries.
+
+When the Op is run, it reports an `InvalidArgument` error if multiple values
+in the summaries to merge use the same tag.
+
+inputs: Can be of any shape. Each must contain serialized `Summary` protocol
+ buffers.
+summary: Scalar. Serialized `Summary` protocol buffer.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
new file mode 100644
index 0000000000..e7b4e92fd5
--- /dev/null
+++ b/tensorflow/core/ops/training_ops.cc
@@ -0,0 +1,199 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ApplyGradientDescent")
+ .Input("var: Ref(T)")
+ .Input("alpha: T")
+ .Input("delta: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' by subtracting 'alpha' * 'delta' from it.
+
+var: Should be from a Variable().
+alpha: Scaling factor. Must be a scalar.
+delta: The change.
+out: Same as "var".
+use_locking: If True, the subtraction will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ApplyAdagrad")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the adagrad scheme.
+
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+
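+For example (values chosen for illustration), with `lr = 0.1`, `accum = 1.0`
+and `grad = 2.0`, accum becomes 1.0 + 4.0 = 5.0 and var decreases by
+0.1 * 2.0 / sqrt(5.0), which is approximately 0.089.
+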
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+grad: The gradient.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("SparseApplyAdagrad")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Input("indices: Tindices")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update relevant entries in '*var' and '*accum' according to the adagrad scheme.
+
+That is, for the rows for which we have a gradient, we update var and accum
+as follows:
+accum += grad * grad
+var -= lr * grad * (1 / sqrt(accum))
+
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Learning rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var and accum.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ApplyMomentum")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Input("momentum: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the momentum scheme.
+
+accum = accum * momentum + grad
+var -= lr * accum
+
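+For example (values chosen for illustration), with `momentum = 0.9`,
+`accum = 0.5`, `grad = 1.0` and `lr = 0.1`, accum becomes
+0.9 * 0.5 + 1.0 = 1.45 and var decreases by 0.1 * 1.45 = 0.145.
+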
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+grad: The gradient.
+momentum: Momentum. Must be a scalar.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("SparseApplyMomentum")
+ .Input("var: Ref(T)")
+ .Input("accum: Ref(T)")
+ .Input("lr: T")
+ .Input("grad: T")
+ .Input("indices: Tindices")
+ .Input("momentum: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+
+That is, for the rows for which we have a gradient, we update var and accum
+as follows:
+
+accum = accum * momentum + grad
+var -= lr * accum
+
+var: Should be from a Variable().
+accum: Should be from a Variable().
+lr: Learning rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var and accum.
+momentum: Momentum. Must be a scalar.
+out: Same as "var".
+use_locking: If True, updating of the var and accum tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ApplyAdam")
+ .Input("var: Ref(T)")
+ .Input("m: Ref(T)")
+ .Input("v: Ref(T)")
+ .Input("beta1_power: T")
+ .Input("beta2_power: T")
+ .Input("lr: T")
+ .Input("beta1: T")
+ .Input("beta2: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the Adam algorithm.
+
+lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t
+v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t
+variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
+
+var: Should be from a Variable().
+m: Should be from a Variable().
+v: Should be from a Variable().
+beta1_power: Must be a scalar.
+beta2_power: Must be a scalar.
+lr: Scaling factor. Must be a scalar.
+beta1: Momentum factor. Must be a scalar.
+beta2: Momentum factor. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+grad: The gradient.
+out: Same as "var".
+use_locking: If True, updating of the var, m, and v tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ApplyRMSProp")
+ .Input("var: Ref(T)")
+ .Input("ms: Ref(T)")
+ .Input("mom: Ref(T)")
+ .Input("lr: T")
+ .Input("rho: T")
+ .Input("momentum: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the RMSProp algorithm.
+
+The basic RMSProp update keeps a moving average of the squared gradient:
+
+mean_square = decay * mean_square + (1-decay) * gradient ** 2
+Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+
+Here `rho` is the decay rate and a momentum term is folded in, so the update
+actually applied is:
+
+ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+var <- var - mom
+
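+For example (values chosen for illustration), with `rho = 0.9`, `ms = 1.0`,
+`momentum = 0.0`, `lr = 0.1`, `epsilon = 1e-10` and `grad = 2.0`, ms becomes
+0.9 * 1.0 + 0.1 * 4.0 = 1.3 and var decreases by 0.1 * 2.0 / sqrt(1.3), which
+is approximately 0.175.
+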
+var: Should be from a Variable().
+ms: Should be from a Variable().
+mom: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+rho: Decay rate. Must be a scalar.
+grad: The gradient.
+out: Same as "var".
+use_locking: If True, updating of the var, ms, and mom tensors will be protected by
+a lock; otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
new file mode 100644
index 0000000000..7cf6c274be
--- /dev/null
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -0,0 +1,65 @@
+# Platform-specific build configurations.
+
+load("/google/protobuf/protobuf", "cc_proto_library")
+load("/google/protobuf/protobuf", "py_proto_library")
+
+# Appends a suffix to a list of deps.
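+#
+# For example (illustrative):
+#   tf_deps(["//tensorflow/core/lib/core"], "_proto_srcs")
+# returns
+#   ["//tensorflow/core/lib/core:core_proto_srcs"]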
+def tf_deps(deps, suffix):
+ tf_deps = []
+
+ # If the package name is in shorthand form (ie: does not contain a ':'),
+ # expand it to the full name.
+ for dep in deps:
+ tf_dep = dep
+
+ if not ":" in dep:
+ dep_pieces = dep.split("/")
+ tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
+
+ tf_deps += [tf_dep + suffix]
+
+ return tf_deps
+
+def tf_proto_library(name, srcs = [], has_services = False,
+ deps = [], visibility = [], testonly = 0,
+ cc_api_version = 2, go_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2):
+ native.filegroup(name=name + "_proto_srcs",
+ srcs=srcs + tf_deps(deps, "_proto_srcs"),
+ testonly=testonly,)
+
+ cc_proto_library(name=name + "_cc",
+ srcs=srcs + tf_deps(deps, "_proto_srcs"),
+ deps=deps,
+ cc_libs = ["//google/protobuf:protobuf"],
+ testonly=testonly,
+ visibility=visibility,)
+
+ py_proto_library(name=name + "_py",
+ srcs=srcs + tf_deps(deps, "_proto_srcs"),
+ deps=deps,
+ py_libs = ["//google/protobuf:protobuf_python"],
+ testonly=testonly,
+ visibility=visibility,)
+
+def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0):
+ py_proto_library(name = name + "_py",
+ srcs = srcs,
+ deps = deps,
+ visibility = visibility,
+ testonly = testonly)
+
+def tf_additional_lib_srcs():
+ return [
+ "platform/default/*.h",
+ "platform/default/*.cc",
+ "platform/posix/*.h",
+ "platform/posix/*.cc",
+ ]
+
+def tf_additional_test_srcs():
+ return ["platform/default/test_benchmark.cc"]
+
+def tf_kernel_tests_linkstatic():
+ return 0
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
new file mode 100644
index 0000000000..44dbc47ad1
--- /dev/null
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -0,0 +1,85 @@
+# Description:
+# Platform-specific build configurations.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("/tensorflow/tensorflow", "tf_copts")
+load("/tensorflow/tensorflow", "tf_cuda_library")
+
+cc_library(
+ name = "gtest",
+ testonly = 1,
+ copts = tf_copts(),
+ deps = [
+ "//external:gtest",
+ ],
+)
+
+cc_library(
+ name = "tensorflow_platform_specific",
+ copts = tf_copts(),
+ linkstatic = 1,
+ deps = [],
+)
+
+tf_cuda_library(
+ name = "stream_executor",
+ deps = [
+ "//tensorflow/stream_executor",
+ ],
+)
+
+cc_library(
+ name = "platformlib",
+ copts = tf_copts(),
+ deps = [
+ "@jpeg_archive//:jpeg",
+ "@png_archive//:png",
+ "@re2//:re2",
+ "//tensorflow/core:protos_cc",
+ ],
+)
+
+cc_library(
+ name = "protos_cc",
+ copts = tf_copts(),
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "test_main",
+ testonly = 1,
+ linkstatic = 1,
+ deps = [],
+)
+
+cc_library(
+ name = "cuda_runtime_extra",
+ linkstatic = 1,
+ deps = [],
+)
+
+filegroup(
+ name = "android_proto_lib_portable_proto",
+ srcs = [],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda",
+ data = [
+ "//third_party/gpus/cuda:lib64/libcudart.so.7.0",
+ ],
+ linkopts = [
+ "-Wl,-rpath,third_party/gpus/cuda/lib64",
+ ],
+ deps = [
+ "//third_party/gpus/cuda:cudart",
+ ],
+)
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
new file mode 100644
index 0000000000..439bf97a2c
--- /dev/null
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -0,0 +1,6 @@
+# Lower-level functionality for build config.
+# The functions in this file might be referred by tensorflow.bzl. They have to
+# be separate to avoid cyclic references.
+
+def tf_cuda_tests_tags():
+ return ["local"]
diff --git a/tensorflow/core/platform/default/dynamic_annotations.h b/tensorflow/core/platform/default/dynamic_annotations.h
new file mode 100644
index 0000000000..1705fb9955
--- /dev/null
+++ b/tensorflow/core/platform/default/dynamic_annotations.h
@@ -0,0 +1,9 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_
+
+// Do nothing for this platform
+#define TF_ANNOTATE_MEMORY_IS_INITIALIZED(ptr, bytes) \
+ do { \
+ } while (0)
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_DYNAMIC_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/default/integral_types.h b/tensorflow/core/platform/default/integral_types.h
new file mode 100644
index 0000000000..04aae172da
--- /dev/null
+++ b/tensorflow/core/platform/default/integral_types.h
@@ -0,0 +1,18 @@
+#ifndef TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#define TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+
+namespace tensorflow {
+
+typedef signed char int8;
+typedef short int16;
+typedef int int32;
+typedef long long int64;
+
+typedef unsigned char uint8;
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned long long uint64;
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc
new file mode 100644
index 0000000000..8a16a537b0
--- /dev/null
+++ b/tensorflow/core/platform/default/logging.cc
@@ -0,0 +1,125 @@
+#include "tensorflow/core/platform/default/logging.h"
+
+#if defined(PLATFORM_POSIX_ANDROID)
+#include <android/log.h>
+#include <sstream>
+#endif
+
+#include <stdlib.h>
+
+namespace tensorflow {
+namespace internal {
+
+LogMessage::LogMessage(const char* fname, int line, int severity)
+ : fname_(fname), line_(line), severity_(severity) {}
+
+#if defined(PLATFORM_POSIX_ANDROID)
+void LogMessage::GenerateLogMessage() {
+ int android_log_level;
+ switch (severity_) {
+ case INFO:
+ android_log_level = ANDROID_LOG_INFO;
+ break;
+ case WARNING:
+ android_log_level = ANDROID_LOG_WARN;
+ break;
+ case ERROR:
+ android_log_level = ANDROID_LOG_ERROR;
+ break;
+ case FATAL:
+ android_log_level = ANDROID_LOG_FATAL;
+ break;
+ default:
+ if (severity_ < INFO) {
+ android_log_level = ANDROID_LOG_VERBOSE;
+ } else {
+ android_log_level = ANDROID_LOG_ERROR;
+ }
+ break;
+ }
+
+ std::stringstream ss;
+ ss << fname_ << ":" << line_ << " " << str();
+ __android_log_write(android_log_level, "native", ss.str().c_str());
+
+ // Android logging at level FATAL does not terminate execution, so abort()
+ // is still required to stop the program.
+ if (severity_ == FATAL) {
+ abort();
+ }
+}
+
+#else
+
+void LogMessage::GenerateLogMessage() {
+ // TODO(jeff,sanjay): For open source version, replace this with something
+ // that logs through the env or something and fill in appropriate time info.
+ fprintf(stderr, "%c %s:%d] %s\n", "IWEF"[severity_], fname_, line_,
+ str().c_str());
+}
+#endif
+
+LogMessage::~LogMessage() { GenerateLogMessage(); }
+
+LogMessageFatal::LogMessageFatal(const char* file, int line)
+ : LogMessage(file, line, FATAL) {}
+LogMessageFatal::~LogMessageFatal() {
+ // abort() ensures we don't return (we promised we would not via
+ // ATTRIBUTE_NORETURN).
+ GenerateLogMessage();
+ abort();
+}
+
+template <>
+void MakeCheckOpValueString(std::ostream* os, const char& v) {
+ if (v >= 32 && v <= 126) {
+ (*os) << "'" << v << "'";
+ } else {
+ (*os) << "char value " << (short)v;
+ }
+}
+
+template <>
+void MakeCheckOpValueString(std::ostream* os, const signed char& v) {
+ if (v >= 32 && v <= 126) {
+ (*os) << "'" << v << "'";
+ } else {
+ (*os) << "signed char value " << (short)v;
+ }
+}
+
+template <>
+void MakeCheckOpValueString(std::ostream* os, const unsigned char& v) {
+ if (v >= 32 && v <= 126) {
+ (*os) << "'" << v << "'";
+ } else {
+ (*os) << "unsigned char value " << (unsigned short)v;
+ }
+}
+
+#if LANG_CXX11
+template <>
+void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& p) {
+ (*os) << "nullptr";
+}
+#endif
+
+CheckOpMessageBuilder::CheckOpMessageBuilder(const char* exprtext)
+ : stream_(new std::ostringstream) {
+ *stream_ << "Check failed: " << exprtext << " (";
+}
+
+CheckOpMessageBuilder::~CheckOpMessageBuilder() { delete stream_; }
+
+std::ostream* CheckOpMessageBuilder::ForVar2() {
+ *stream_ << " vs. ";
+ return stream_;
+}
+
+string* CheckOpMessageBuilder::NewString() {
+ *stream_ << ")";
+ return new string(stream_->str());
+}
+
+} // namespace internal
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
new file mode 100644
index 0000000000..034178751e
--- /dev/null
+++ b/tensorflow/core/platform/default/logging.h
@@ -0,0 +1,258 @@
+#ifndef TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
+#define TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
+
+#include <sstream>
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+const int INFO = 0; // base_logging::INFO;
+const int WARNING = 1; // base_logging::WARNING;
+const int ERROR = 2; // base_logging::ERROR;
+const int FATAL = 3; // base_logging::FATAL;
+const int NUM_SEVERITIES = 4; // base_logging::NUM_SEVERITIES;
+
+namespace internal {
+
+class LogMessage : public std::basic_ostringstream<char> {
+ public:
+ LogMessage(const char* fname, int line, int severity);
+ ~LogMessage();
+
+ protected:
+ void GenerateLogMessage();
+
+ private:
+ const char* fname_;
+ int line_;
+ int severity_;
+};
+
+// LogMessageFatal ensures the process will exit in failure after
+// logging this message.
+class LogMessageFatal : public LogMessage {
+ public:
+ LogMessageFatal(const char* file, int line) TF_ATTRIBUTE_COLD;
+ ~LogMessageFatal() TF_ATTRIBUTE_NORETURN;
+};
+
+#define _TF_LOG_INFO \
+ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO)
+#define _TF_LOG_WARNING \
+ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::WARNING)
+#define _TF_LOG_ERROR \
+ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::ERROR)
+#define _TF_LOG_FATAL \
+ ::tensorflow::internal::LogMessageFatal(__FILE__, __LINE__)
+
+#define LOG(severity) _TF_LOG_##severity
+
+// TODO(jeff): Define a proper implementation of VLOG_IS_ON
+#define VLOG_IS_ON(lvl) ((lvl) <= 0)
+
+#define VLOG(lvl) \
+ if (VLOG_IS_ON(lvl)) \
+ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO)
+
+// CHECK dies with a fatal error if condition is not true. It is *not*
+// controlled by NDEBUG, so the check will be executed regardless of
+// compilation mode. Therefore, it is safe to do things like:
+// CHECK(fp->Write(x) == 4)
+#define CHECK(condition) \
+ if (TF_PREDICT_FALSE(!(condition))) \
+ LOG(FATAL) << "Check failed: " #condition " "
+
+// Function is overloaded for integral types to allow static const
+// integrals declared in classes and not defined to be used as arguments to
+// CHECK* macros. It's not encouraged though.
+template <typename T>
+inline const T& GetReferenceableValue(const T& t) {
+ return t;
+}
+inline char GetReferenceableValue(char t) { return t; }
+inline unsigned char GetReferenceableValue(unsigned char t) { return t; }
+inline signed char GetReferenceableValue(signed char t) { return t; }
+inline short GetReferenceableValue(short t) { return t; }
+inline unsigned short GetReferenceableValue(unsigned short t) { return t; }
+inline int GetReferenceableValue(int t) { return t; }
+inline unsigned int GetReferenceableValue(unsigned int t) { return t; }
+inline long GetReferenceableValue(long t) { return t; }
+inline unsigned long GetReferenceableValue(unsigned long t) { return t; }
+inline long long GetReferenceableValue(long long t) { return t; }
+inline unsigned long long GetReferenceableValue(unsigned long long t) {
+ return t;
+}
+
+// This formats a value for a failing CHECK_XX statement. Ordinarily,
+// it uses the definition for operator<<, with a few special cases below.
+template <typename T>
+inline void MakeCheckOpValueString(std::ostream* os, const T& v) {
+ (*os) << v;
+}
+
+// Overrides for char types provide readable values for unprintable
+// characters.
+template <>
+void MakeCheckOpValueString(std::ostream* os, const char& v);
+template <>
+void MakeCheckOpValueString(std::ostream* os, const signed char& v);
+template <>
+void MakeCheckOpValueString(std::ostream* os, const unsigned char& v);
+
+#if LANG_CXX11
+// We need an explicit specialization for std::nullptr_t.
+template <>
+void MakeCheckOpValueString(std::ostream* os, const std::nullptr_t& p);
+#endif
+
+// A container for a string pointer which can be evaluated to a bool -
+// true iff the pointer is non-NULL.
+struct CheckOpString {
+ CheckOpString(string* str) : str_(str) {}
+ // No destructor: if str_ is non-NULL, we're about to LOG(FATAL),
+ // so there's no point in cleaning up str_.
+ operator bool() const { return TF_PREDICT_FALSE(str_ != NULL); }
+ string* str_;
+};
+
+// Build the error message string. Specify no inlining for code size.
+template <typename T1, typename T2>
+string* MakeCheckOpString(const T1& v1, const T2& v2,
+ const char* exprtext) TF_ATTRIBUTE_NOINLINE;
+
+// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX
+// statement. See MakeCheckOpString for sample usage. Other
+// approaches were considered: use of a template method (e.g.,
+// base::BuildCheckOpString(exprtext, base::Print<T1>, &v1,
+// base::Print<T2>, &v2), however this approach has complications
+// related to volatile arguments and function-pointer arguments).
+class CheckOpMessageBuilder {
+ public:
+ // Inserts "exprtext" and " (" to the stream.
+ explicit CheckOpMessageBuilder(const char* exprtext);
+ // Deletes "stream_".
+ ~CheckOpMessageBuilder();
+ // For inserting the first variable.
+ std::ostream* ForVar1() { return stream_; }
+ // For inserting the second variable (adds an intermediate " vs. ").
+ std::ostream* ForVar2();
+ // Get the result (inserts the closing ")").
+ string* NewString();
+
+ private:
+ std::ostringstream* stream_;
+};
+
+template <typename T1, typename T2>
+string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) {
+ CheckOpMessageBuilder comb(exprtext);
+ MakeCheckOpValueString(comb.ForVar1(), v1);
+ MakeCheckOpValueString(comb.ForVar2(), v2);
+ return comb.NewString();
+}
+
+// Helper functions for CHECK_OP macro.
+// The (int, int) specialization works around the issue that the compiler
+// will not instantiate the template version of the function on values of
+// unnamed enum type - see comment below.
+#define TF_DEFINE_CHECK_OP_IMPL(name, op) \
+ template <typename T1, typename T2> \
+ inline string* name##Impl(const T1& v1, const T2& v2, \
+ const char* exprtext) { \
+ if (TF_PREDICT_TRUE(v1 op v2)) \
+ return NULL; \
+ else \
+ return ::tensorflow::internal::MakeCheckOpString(v1, v2, exprtext); \
+ } \
+ inline string* name##Impl(int v1, int v2, const char* exprtext) { \
+ return name##Impl<int, int>(v1, v2, exprtext); \
+ }
+
+// We use the full name Check_EQ, Check_NE, etc. in case the file including
+// base/logging.h provides its own #defines for the simpler names EQ, NE, etc.
+// This happens if, for example, those are used as token names in a
+// yacc grammar.
+TF_DEFINE_CHECK_OP_IMPL(Check_EQ,
+ == ) // Compilation error with CHECK_EQ(NULL, x)?
+TF_DEFINE_CHECK_OP_IMPL(Check_NE, != ) // Use CHECK(x == NULL) instead.
+TF_DEFINE_CHECK_OP_IMPL(Check_LE, <= )
+TF_DEFINE_CHECK_OP_IMPL(Check_LT, < )
+TF_DEFINE_CHECK_OP_IMPL(Check_GE, >= )
+TF_DEFINE_CHECK_OP_IMPL(Check_GT, > )
+#undef TF_DEFINE_CHECK_OP_IMPL
+
+// In optimized mode, use CheckOpString to hint to compiler that
+// the while condition is unlikely.
+#define CHECK_OP_LOG(name, op, val1, val2) \
+ while (::tensorflow::internal::CheckOpString _result = \
+ ::tensorflow::internal::name##Impl( \
+ ::tensorflow::internal::GetReferenceableValue(val1), \
+ ::tensorflow::internal::GetReferenceableValue(val2), \
+ #val1 " " #op " " #val2)) \
+ ::tensorflow::internal::LogMessageFatal(__FILE__, __LINE__) << *(_result.str_)
+
+#define CHECK_OP(name, op, val1, val2) CHECK_OP_LOG(name, op, val1, val2)
+
+// CHECK_EQ/NE/...
+#define CHECK_EQ(val1, val2) CHECK_OP(Check_EQ, ==, val1, val2)
+#define CHECK_NE(val1, val2) CHECK_OP(Check_NE, !=, val1, val2)
+#define CHECK_LE(val1, val2) CHECK_OP(Check_LE, <=, val1, val2)
+#define CHECK_LT(val1, val2) CHECK_OP(Check_LT, <, val1, val2)
+#define CHECK_GE(val1, val2) CHECK_OP(Check_GE, >=, val1, val2)
+#define CHECK_GT(val1, val2) CHECK_OP(Check_GT, >, val1, val2)
+#define CHECK_NOTNULL(val) \
+ ::tensorflow::internal::CheckNotNull(__FILE__, __LINE__, \
+ "'" #val "' Must be non NULL", (val))
+
+#ifndef NDEBUG
+// DCHECK_EQ/NE/...
+#define DCHECK(condition) CHECK(condition)
+#define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
+#define DCHECK_NE(val1, val2) CHECK_NE(val1, val2)
+#define DCHECK_LE(val1, val2) CHECK_LE(val1, val2)
+#define DCHECK_LT(val1, val2) CHECK_LT(val1, val2)
+#define DCHECK_GE(val1, val2) CHECK_GE(val1, val2)
+#define DCHECK_GT(val1, val2) CHECK_GT(val1, val2)
+
+#else
+
+#define DCHECK(condition) \
+ while (false && (condition)) LOG(FATAL)
+
+// NDEBUG is defined, so DCHECK_EQ(x, y) and so on do nothing.
+// However, we still want the compiler to parse x and y, because
+// we don't want to lose potentially useful errors and warnings.
+// _TF_DCHECK_NOP is a helper, and should not be used outside of this file.
+#define _TF_DCHECK_NOP(x, y) \
+ while (false && ((void)(x), (void)(y), 0)) LOG(FATAL)
+
+#define DCHECK_EQ(x, y) _TF_DCHECK_NOP(x, y)
+#define DCHECK_NE(x, y) _TF_DCHECK_NOP(x, y)
+#define DCHECK_LE(x, y) _TF_DCHECK_NOP(x, y)
+#define DCHECK_LT(x, y) _TF_DCHECK_NOP(x, y)
+#define DCHECK_GE(x, y) _TF_DCHECK_NOP(x, y)
+#define DCHECK_GT(x, y) _TF_DCHECK_NOP(x, y)
+
+#endif
+
+// These are for when you don't want a CHECK failure to print a verbose
+// stack trace. The implementation of CHECK* in this file already doesn't.
+#define QCHECK(condition) CHECK(condition)
+#define QCHECK_EQ(x, y) CHECK_EQ(x, y)
+#define QCHECK_NE(x, y) CHECK_NE(x, y)
+#define QCHECK_LE(x, y) CHECK_LE(x, y)
+#define QCHECK_LT(x, y) CHECK_LT(x, y)
+#define QCHECK_GE(x, y) CHECK_GE(x, y)
+#define QCHECK_GT(x, y) CHECK_GT(x, y)
+
+template <typename T>
+T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
+ if (t == nullptr) {
+ LogMessageFatal(file, line) << string(exprtext);
+ }
+ return std::forward<T>(t);
+}
+
+} // namespace internal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
new file mode 100644
index 0000000000..b26b418e1b
--- /dev/null
+++ b/tensorflow/core/platform/default/mutex.h
@@ -0,0 +1,33 @@
+#ifndef TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
+#define TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
+
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+
+namespace tensorflow {
+
+enum LinkerInitialized { LINKER_INITIALIZED };
+
+// A class that wraps around the std::mutex implementation, only adding an
+// additional LinkerInitialized constructor interface.
+class mutex : public std::mutex {
+ public:
+ mutex() {}
+ // The default implementation of std::mutex is safe to use after the linker
+ // initializations
+ explicit mutex(LinkerInitialized x) {}
+};
+
+using std::condition_variable;
+typedef std::unique_lock<std::mutex> mutex_lock;
+
+inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
+ condition_variable* cv, int64 ms) {
+ std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms));
+ return (s == std::cv_status::timeout) ? kCond_Timeout : kCond_MaybeNotified;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
diff --git a/tensorflow/core/platform/default/protobuf.h b/tensorflow/core/platform/default/protobuf.h
new file mode 100644
index 0000000000..f6083c318d
--- /dev/null
+++ b/tensorflow/core/platform/default/protobuf.h
@@ -0,0 +1,13 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_
+
+#include "google/protobuf/descriptor.h"
+#include "google/protobuf/io/coded_stream.h"
+#include "google/protobuf/io/zero_copy_stream.h"
+#include "google/protobuf/text_format.h"
+
+namespace tensorflow {
+namespace protobuf = ::google::protobuf;
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_H_
diff --git a/tensorflow/core/platform/default/stream_executor_util.h b/tensorflow/core/platform/default/stream_executor_util.h
new file mode 100644
index 0000000000..d7fad4e233
--- /dev/null
+++ b/tensorflow/core/platform/default/stream_executor_util.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_PLATFORM_DEFAULT_STREAM_EXECUTOR_UTIL_H_
+#define TENSORFLOW_PLATFORM_DEFAULT_STREAM_EXECUTOR_UTIL_H_
+
+#include "tensorflow/stream_executor/lib/status.h"
+
+namespace tensorflow {
+
+namespace gpu = ::perftools::gputools;
+
+// On the open-source platform, stream_executor currently uses
+// tensorflow::Status
+inline Status FromStreamExecutorStatus(
+ const perftools::gputools::port::Status& s) {
+ return s;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_DEFAULT_STREAM_EXECUTOR_UTIL_H_
diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc
new file mode 100644
index 0000000000..4004bf026b
--- /dev/null
+++ b/tensorflow/core/platform/default/test_benchmark.cc
@@ -0,0 +1,162 @@
+#include "tensorflow/core/platform/test_benchmark.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace testing {
+
+static std::vector<Benchmark*>* all_benchmarks = nullptr;
+static std::string label;
+static int64 bytes_processed;
+static int64 items_processed;
+static int64 accum_time = 0;
+static int64 start_time = 0;
+static Env* env;
+
+Benchmark::Benchmark(const char* name, void (*fn)(int))
+ : name_(name), num_args_(0), fn0_(fn) {
+ args_.push_back(-1);
+ Register();
+}
+
+Benchmark::Benchmark(const char* name, void (*fn)(int, int))
+ : name_(name), num_args_(1), fn1_(fn) {
+ Register();
+}
+
+Benchmark* Benchmark::Arg(int x) {
+ CHECK_EQ(num_args_, 1);
+ args_.push_back(x);
+ return this;
+}
+
+Benchmark* Benchmark::Range(int lo, int hi) {
+ Arg(lo);
+ for (int32 i = 1; i < kint32max / 8 && i < hi; i *= 8) {
+ Arg(i);
+ }
+ if (lo != hi) Arg(hi);
+ return this;
+}
+
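+// Typical usage (a sketch; the BENCHMARK registration macro is declared in
+// the corresponding test_benchmark.h header):
+//
+//   static void BM_MyKernel(int iters, int size) {
+//     for (int i = 0; i < iters; ++i) {
+//       // ... run the code under test for `size` ...
+//     }
+//   }
+//   BENCHMARK(BM_MyKernel)->Arg(8)->Arg(64);
+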
+void Benchmark::Run(const char* pattern) {
+ if (!all_benchmarks) return;
+
+ if (StringPiece(pattern) == "all") {
+ pattern = ".*";
+ }
+
+ // Compute name width.
+ int width = 10;
+ string name;
+ for (auto b : *all_benchmarks) {
+ name = b->name_;
+ for (auto arg : b->args_) {
+ name.resize(b->name_.size());
+ if (arg >= 0) {
+ strings::StrAppend(&name, "/", arg);
+ }
+ if (RE2::PartialMatch(name, pattern)) {
+ width = std::max<int>(width, name.size());
+ }
+ }
+ }
+
+ printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
+ printf("%s\n", string(width + 22, '-').c_str());
+ for (auto b : *all_benchmarks) {
+ name = b->name_;
+ for (auto arg : b->args_) {
+ name.resize(b->name_.size());
+ if (arg >= 0) {
+ strings::StrAppend(&name, "/", arg);
+ }
+ if (!RE2::PartialMatch(name, pattern)) {
+ continue;
+ }
+
+ int iters;
+ double seconds;
+ b->Run(arg, &iters, &seconds);
+
+ char buf[100];
+ std::string full_label = label;
+ if (bytes_processed > 0) {
+ snprintf(buf, sizeof(buf), " %.1fMB/s",
+ (bytes_processed * 1e-6) / seconds);
+ full_label += buf;
+ }
+ if (items_processed > 0) {
+ snprintf(buf, sizeof(buf), " %.1fM items/s",
+ (items_processed * 1e-6) / seconds);
+ full_label += buf;
+ }
+ printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(),
+ seconds * 1e9 / iters, iters, full_label.c_str());
+ }
+ }
+}
+
+void Benchmark::Register() {
+ if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
+ all_benchmarks->push_back(this);
+}
+
+void Benchmark::Run(int arg, int* run_count, double* run_seconds) {
+ env = Env::Default();
+ static const int64 kMinIters = 100;
+ static const int64 kMaxIters = 1000000000;
+ static const double kMinTime = 0.5;
+ int64 iters = kMinIters;
+ while (true) {
+ accum_time = 0;
+ start_time = env->NowMicros();
+ bytes_processed = -1;
+ items_processed = -1;
+ label.clear();
+ if (fn0_) {
+ (*fn0_)(iters);
+ } else {
+ (*fn1_)(iters, arg);
+ }
+ StopTiming();
+ const double seconds = accum_time * 1e-6;
+ if (seconds >= kMinTime || iters >= kMaxIters) {
+ *run_count = iters;
+ *run_seconds = seconds;
+ return;
+ }
+
+ // Update number of iterations. Overshoot by 40% in an attempt
+ // to succeed the next time.
+ double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9);
+ multiplier = std::min(10.0, multiplier);
+ if (multiplier <= 1.0) multiplier *= 2.0;
+ iters = std::max<int64>(multiplier * iters, iters + 1);
+ iters = std::min(iters, kMaxIters);
+ }
+}
+
+// TODO(vrv): Add support for running a subset of benchmarks by having
+// RunBenchmarks take in a spec (and maybe other options such as
+// benchmark_min_time, etc).
+void RunBenchmarks() { Benchmark::Run("all"); }
+void SetLabel(const std::string& l) { label = l; }
+void BytesProcessed(int64 n) { bytes_processed = n; }
+void ItemsProcessed(int64 n) { items_processed = n; }
+void StartTiming() {
+ if (start_time == 0) start_time = env->NowMicros();
+}
+void StopTiming() {
+ if (start_time != 0) {
+ accum_time += (env->NowMicros() - start_time);
+ start_time = 0;
+ }
+}
+void UseRealTime() {}
+
+} // namespace testing
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h
new file mode 100644
index 0000000000..fed39bf810
--- /dev/null
+++ b/tensorflow/core/platform/default/thread_annotations.h
@@ -0,0 +1,185 @@
+// Copyright (c) 2008, Google Inc.
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+// ---
+//
+// This header file contains the macro definitions for thread safety
+// annotations that allow the developers to document the locking policies
+// of their multi-threaded code. The annotations can also help program
+// analysis tools to identify potential thread safety issues.
+//
+// The primary documentation on these annotations is external:
+// http://clang.llvm.org/docs/ThreadSafetyAnalysis.html
+//
+// The annotations are implemented using compiler attributes.
+// Using the macros defined here instead of the raw attributes allows
+// for portability and future compatibility.
+//
+// When referring to mutexes in the arguments of the attributes, you should
+// use variable names or more complex expressions (e.g. my_object->mutex_)
+// that evaluate to a concrete mutex object whenever possible. If the mutex
+// you want to refer to is not in scope, you may use a member pointer
+// (e.g. &MyClass::mutex_) to refer to a mutex in some (unknown) object.
+//
+
+#ifndef TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+
+#if defined(__clang__) && (!defined(SWIG))
+#define THREAD_ANNOTATION_ATTRIBUTE__(x) __attribute__((x))
+#else
+#define THREAD_ANNOTATION_ATTRIBUTE__(x) // no-op
+#endif
+
+// Document if a shared variable/field needs to be protected by a mutex.
+// GUARDED_BY allows the user to specify a particular mutex that should be
+// held when accessing the annotated variable. GUARDED_VAR indicates that
+// a shared variable is guarded by some unspecified mutex, for use in rare
+// cases where a valid mutex expression cannot be specified.
+#define GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(guarded_by(x))
+#define GUARDED_VAR THREAD_ANNOTATION_ATTRIBUTE__(guarded)
+
+// Document if the memory location pointed to by a pointer should be guarded
+// by a mutex when dereferencing the pointer. PT_GUARDED_VAR is analogous to
+// GUARDED_VAR. Note that a pointer variable to a shared memory location
+// could itself be a shared variable. For example, if a shared global pointer
+// q, which is guarded by mu1, points to a shared memory location that is
+// guarded by mu2, q should be annotated as follows:
+// int *q GUARDED_BY(mu1) PT_GUARDED_BY(mu2);
+#define PT_GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded_by(x))
+#define PT_GUARDED_VAR THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded)
+
+// Document the acquisition order between locks that can be held
+// simultaneously by a thread. For any two locks that need to be annotated
+// to establish an acquisition order, only one of them needs the annotation.
+// (i.e. You don't have to annotate both locks with both ACQUIRED_AFTER
+// and ACQUIRED_BEFORE.)
+#define ACQUIRED_AFTER(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(acquired_after(__VA_ARGS__))
+
+#define ACQUIRED_BEFORE(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__))
+
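+// For example (illustrative): if mu1 must always be acquired before mu2,
+// annotating the declaration of mu2 with ACQUIRED_AFTER(mu1) is enough to
+// document the ordering for both mutexes.
+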
+// Document a function that expects a mutex to be held prior to entry.
+// The mutex is expected to be held both on entry to and exit from the
+// function.
+#define EXCLUSIVE_LOCKS_REQUIRED(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(exclusive_locks_required(__VA_ARGS__))
+
+#define SHARED_LOCKS_REQUIRED(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(shared_locks_required(__VA_ARGS__))
+
+// Document the locks acquired in the body of the function. These locks
+// cannot be held when calling this function (for instance, when the
+// mutex implementation is non-reentrant).
+#define LOCKS_EXCLUDED(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(locks_excluded(__VA_ARGS__))
+
+// Document a function that returns a mutex without acquiring it. For example,
+// a public getter method that returns a pointer to a private mutex should
+// be annotated with LOCK_RETURNED.
+#define LOCK_RETURNED(x) THREAD_ANNOTATION_ATTRIBUTE__(lock_returned(x))
+
+// Document if a class/type is a lockable type (such as the Mutex class).
+#define LOCKABLE THREAD_ANNOTATION_ATTRIBUTE__(lockable)
+
+// Document if a class does RAII locking (such as the MutexLock class).
+// The constructor should use LOCK_FUNCTION to specify the mutex that is
+// acquired, and the destructor should use UNLOCK_FUNCTION with no arguments;
+// the analysis will assume that the destructor unlocks whatever the
+// constructor locked.
+#define SCOPED_LOCKABLE THREAD_ANNOTATION_ATTRIBUTE__(scoped_lockable)
+
+// Document functions that acquire a lock in the body of a function, and do
+// not release it.
+#define EXCLUSIVE_LOCK_FUNCTION(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(exclusive_lock_function(__VA_ARGS__))
+
+#define SHARED_LOCK_FUNCTION(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(shared_lock_function(__VA_ARGS__))
+
+// Document functions that expect a lock to be held on entry to the function,
+// and release it in the body of the function.
+#define UNLOCK_FUNCTION(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(unlock_function(__VA_ARGS__))
+
+// Document functions that try to acquire a lock, and return success or failure
+// (or a non-boolean value that can be interpreted as a boolean).
+// The first argument should be true for functions that return true on success,
+// or false for functions that return false on success. The second argument
+// specifies the mutex that is locked on success. If unspecified, it is assumed
+// to be 'this'.
+#define EXCLUSIVE_TRYLOCK_FUNCTION(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(exclusive_trylock_function(__VA_ARGS__))
+
+#define SHARED_TRYLOCK_FUNCTION(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(shared_trylock_function(__VA_ARGS__))
+
+// Document functions that dynamically check to see if a lock is held, and fail
+// if it is not held.
+#define ASSERT_EXCLUSIVE_LOCK(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(assert_exclusive_lock(__VA_ARGS__))
+
+#define ASSERT_SHARED_LOCK(...) \
+ THREAD_ANNOTATION_ATTRIBUTE__(assert_shared_lock(__VA_ARGS__))
+
+// Turns off thread safety checking within the body of a particular function.
+// This is used as an escape hatch for cases where either (a) the function
+// is correct, but the locking is more complicated than the analyzer can handle,
+// or (b) the function contains race conditions that are known to be benign.
+#define NO_THREAD_SAFETY_ANALYSIS \
+ THREAD_ANNOTATION_ATTRIBUTE__(no_thread_safety_analysis)
+
+// TS_UNCHECKED should be placed around lock expressions that are not valid
+// C++ syntax, but which are present for documentation purposes. These
+// annotations will be ignored by the analysis.
+#define TS_UNCHECKED(x) ""
+
+// Disables warnings for a single read operation. This can be used to do racy
+// reads of guarded data members, in cases where the race is benign.
+#define TS_UNCHECKED_READ(x) \
+ ::tensorflow::thread_safety_analysis::ts_unchecked_read(x)
+
+namespace tensorflow {
+namespace thread_safety_analysis {
+
+// Takes a reference to a guarded data member, and returns an unguarded
+// reference.
+template <class T>
+inline const T& ts_unchecked_read(const T& v) NO_THREAD_SAFETY_ANALYSIS {
+ return v;
+}
+
+template <class T>
+inline T& ts_unchecked_read(T& v) NO_THREAD_SAFETY_ANALYSIS {
+ return v;
+}
+} // namespace thread_safety_analysis
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/default/tracing.cc b/tensorflow/core/platform/default/tracing.cc
new file mode 100644
index 0000000000..a4ddfad928
--- /dev/null
+++ b/tensorflow/core/platform/default/tracing.cc
@@ -0,0 +1,37 @@
+#include "tensorflow/core/platform/tracing.h"
+
+#include <unistd.h>
+
+namespace tensorflow {
+namespace port {
+
+void Tracing::RegisterEvent(EventCategory id, const char* name) {
+ // TODO(opensource): implement
+}
+
+void Tracing::Initialize() {}
+
+static bool TryGetEnv(const char* name, const char** value) {
+ *value = getenv(name);
+ return *value != nullptr && (*value)[0] != '\0';
+}
+
+const char* Tracing::LogDir() {
+ const char* dir;
+ if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
+ if (TryGetEnv("TMP", &dir)) return dir;
+ if (TryGetEnv("TMPDIR", &dir)) return dir;
+ dir = "/tmp";
+ if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
+ return "."; // Default to current directory.
+}
+
+static bool DoInit() {
+ Tracing::Initialize();
+ return true;
+}
+
+static const bool dummy = DoInit();
+
+} // namespace port
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/tracing_impl.h b/tensorflow/core/platform/default/tracing_impl.h
new file mode 100644
index 0000000000..e2f5d3cb3f
--- /dev/null
+++ b/tensorflow/core/platform/default/tracing_impl.h
@@ -0,0 +1,44 @@
+#ifndef TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#define TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
+
+// Stub implementations of tracing functionality.
+
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+namespace port {
+
+// Definitions that do nothing for platforms that don't have underlying thread
+// tracing support.
+#define TRACELITERAL(a) \
+ do { \
+ } while (0)
+#define TRACESTRING(s) \
+ do { \
+ } while (0)
+#define TRACEPRINTF(format, ...) \
+ do { \
+ } while (0)
+
+inline uint64 Tracing::UniqueId() { return random::New64(); }
+inline bool Tracing::IsActive() { return false; }
+inline void Tracing::RegisterCurrentThread(const char* name) {}
+
+// Posts an atomic threadscape event with the supplied category and arg.
+inline void Tracing::RecordEvent(EventCategory category, uint64 arg) {
+ // TODO(opensource): Implement
+}
+
+inline Tracing::ScopedActivity::ScopedActivity(EventCategory category,
+ uint64 arg)
+ : enabled_(false), region_id_(category_id_[category]) {}
+
+inline Tracing::ScopedActivity::~ScopedActivity() {}
+
+} // namespace port
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
new file mode 100644
index 0000000000..3e3c0ad74e
--- /dev/null
+++ b/tensorflow/core/platform/env.cc
@@ -0,0 +1,129 @@
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+Env::~Env() {}
+
+RandomAccessFile::~RandomAccessFile() {}
+
+WritableFile::~WritableFile() {}
+
+Thread::~Thread() {}
+
+EnvWrapper::~EnvWrapper() {}
+
+Status ReadFileToString(Env* env, const string& fname, string* data) {
+ data->clear();
+ RandomAccessFile* file;
+ Status s = env->NewRandomAccessFile(fname, &file);
+ if (!s.ok()) {
+ return s;
+ }
+ int64 offset = 0;
+ static const int kBufferSize = 8192;
+ char* space = new char[kBufferSize];
+ while (true) {
+ StringPiece fragment;
+ s = file->Read(offset, kBufferSize, &fragment, space);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) { // No more bytes, but not an error
+ s = Status::OK();
+ data->append(fragment.data(), fragment.size());
+ }
+ break;
+ }
+ offset += fragment.size();
+ data->append(fragment.data(), fragment.size());
+ if (fragment.empty()) {
+ break;
+ }
+ }
+ delete[] space;
+ delete file;
+ return s;
+}
+
+Status WriteStringToFile(Env* env, const string& fname,
+ const StringPiece& data) {
+ WritableFile* file;
+ Status s = env->NewWritableFile(fname, &file);
+ if (!s.ok()) {
+ return s;
+ }
+ s = file->Append(data);
+ if (s.ok()) {
+ s = file->Close();
+ }
+ delete file;
+ return s;
+}
+
+// A ZeroCopyInputStream on a RandomAccessFile.
+namespace {
+class FileStream : public ::tensorflow::protobuf::io::ZeroCopyInputStream {
+ public:
+ explicit FileStream(RandomAccessFile* file) : file_(file), pos_(0) {}
+
+ void BackUp(int count) override { pos_ -= count; }
+ bool Skip(int count) override {
+ pos_ += count;
+ return true;
+ }
+ int64 ByteCount() const override { return pos_; }
+ Status status() const { return status_; }
+
+ bool Next(const void** data, int* size) override {
+ StringPiece result;
+ Status s = file_->Read(pos_, kBufSize, &result, scratch_);
+ if (result.empty()) {
+ status_ = s;
+ return false;
+ }
+ pos_ += result.size();
+ *data = result.data();
+ *size = result.size();
+ return true;
+ }
+
+ private:
+ static const int kBufSize = 512 << 10;
+
+ RandomAccessFile* file_;
+ int64 pos_;
+ Status status_;
+ char scratch_[kBufSize];
+};
+
+} // namespace
+
+Status ReadBinaryProto(Env* env, const string& fname,
+ ::tensorflow::protobuf::MessageLite* proto) {
+ RandomAccessFile* file;
+ auto s = env->NewRandomAccessFile(fname, &file);
+ if (!s.ok()) {
+ return s;
+ }
+ std::unique_ptr<RandomAccessFile> file_holder(file);
+ std::unique_ptr<FileStream> stream(new FileStream(file));
+
+ // TODO(jiayq): the following coded stream is for debugging purposes to allow
+ // one to parse arbitrarily large messages for MessageLite. One most likely
+ // doesn't want to put protobufs larger than 64MB on Android, so we should
+ // eventually remove this and fail loudly when a large protobuf is passed in.
+ ::tensorflow::protobuf::io::CodedInputStream coded_stream(stream.get());
+ // Total bytes hard limit / warning limit are set to 1GB and 512MB
+ // respectively.
+ coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
+
+ if (!proto->ParseFromCodedStream(&coded_stream)) {
+ s = stream->status();
+ if (s.ok()) {
+ s = Status(error::DATA_LOSS, "Parse error");
+ }
+ }
+ return s;
+}
+
+} // namespace tensorflow
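As a usage illustration (not part of this diff), here is a sketch of the helpers defined above. The file paths and the use of `GraphDef` are made up for the example; error handling assumes the `TF_CHECK_OK` and `LOG` macros pulled in by the listed headers.

```c++
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/env.h"

void EnvHelpersExample() {
  tensorflow::Env* env = tensorflow::Env::Default();

  // Write a small text file, then read it back.
  TF_CHECK_OK(tensorflow::WriteStringToFile(env, "/tmp/hello.txt", "hello"));
  tensorflow::string contents;
  TF_CHECK_OK(tensorflow::ReadFileToString(env, "/tmp/hello.txt", &contents));

  // Parse a binary proto from disk. ReadBinaryProto raises the coded-stream
  // limits (1GB hard / 512MB warning), so large messages parse without
  // hitting protobuf's default 64MB cap.
  tensorflow::GraphDef graph;
  tensorflow::Status s =
      tensorflow::ReadBinaryProto(env, "/tmp/graph.pb", &graph);
  if (!s.ok()) {
    LOG(ERROR) << "Could not load graph: " << s.ToString();
  }
}
```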
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
new file mode 100644
index 0000000000..be15c4a5cb
--- /dev/null
+++ b/tensorflow/core/platform/env_test.cc
@@ -0,0 +1,31 @@
+#include "tensorflow/core/public/env.h"
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+struct EnvTest {};
+
+TEST(EnvTest, ReadFileToString) {
+ Env* env = Env::Default();
+ const string dir = testing::TmpDir();
+ for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000}) {
+ const string filename = io::JoinPath(dir, strings::StrCat("file", length));
+
+ // Write a file with the given length
+ string input(length, 0);
+ for (int i = 0; i < length; i++) input[i] = i;
+ TF_CHECK_OK(WriteStringToFile(env, filename, input));
+
+ // Read the file back and check equality
+ string output;
+ TF_CHECK_OK(ReadFileToString(env, filename, &output));
+ CHECK_EQ(length, output.size());
+ CHECK_EQ(input, output);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/init_main.h b/tensorflow/core/platform/init_main.h
new file mode 100644
index 0000000000..ce3d1fbc2f
--- /dev/null
+++ b/tensorflow/core/platform/init_main.h
@@ -0,0 +1,16 @@
+#ifndef TENSORFLOW_PLATFORM_INIT_MAIN_H_
+#define TENSORFLOW_PLATFORM_INIT_MAIN_H_
+
+namespace tensorflow {
+namespace port {
+
+// Platform-specific initialization routine that may be invoked by a
+// main() program that uses TensorFlow.
+//
+// Default implementation does nothing.
+void InitMain(const char* usage, int* argc, char*** argv);
+
+} // namespace port
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_INIT_MAIN_H_
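A minimal sketch of the intended call site (the program itself is hypothetical): InitMain is meant to be called once at the top of main(), before other TensorFlow work, so ports that do parse flags or set up logging get a chance to run.

```c++
#include "tensorflow/core/platform/init_main.h"

int main(int argc, char** argv) {
  // A no-op on the default POSIX port, but other ports may parse flags or
  // initialize logging here, so call it before doing real work.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // ... application code ...
  return 0;
}
```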
diff --git a/tensorflow/core/platform/integral_types_test.cc b/tensorflow/core/platform/integral_types_test.cc
new file mode 100644
index 0000000000..067787a9f4
--- /dev/null
+++ b/tensorflow/core/platform/integral_types_test.cc
@@ -0,0 +1,33 @@
+#include "tensorflow/core/platform/port.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+TEST(IntegralTypes, Basic) {
+ EXPECT_EQ(1, sizeof(int8));
+ EXPECT_EQ(2, sizeof(int16));
+ EXPECT_EQ(4, sizeof(int32));
+ EXPECT_EQ(8, sizeof(int64));
+
+ EXPECT_EQ(1, sizeof(uint8));
+ EXPECT_EQ(2, sizeof(uint16));
+ EXPECT_EQ(4, sizeof(uint32));
+ EXPECT_EQ(8, sizeof(uint64));
+}
+
+TEST(IntegralTypes, MinAndMaxConstants) {
+ EXPECT_EQ(static_cast<uint8>(kint8min), static_cast<uint8>(kint8max) + 1);
+ EXPECT_EQ(static_cast<uint16>(kint16min), static_cast<uint16>(kint16max) + 1);
+ EXPECT_EQ(static_cast<uint32>(kint32min), static_cast<uint32>(kint32max) + 1);
+ EXPECT_EQ(static_cast<uint64>(kint64min), static_cast<uint64>(kint64max) + 1);
+
+ EXPECT_EQ(0, static_cast<uint8>(kuint8max + 1));
+ EXPECT_EQ(0, static_cast<uint16>(kuint16max + 1));
+ EXPECT_EQ(0, static_cast<uint32>(kuint32max + 1));
+ EXPECT_EQ(0, static_cast<uint64>(kuint64max + 1));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h
new file mode 100644
index 0000000000..66caf22ede
--- /dev/null
+++ b/tensorflow/core/platform/logging.h
@@ -0,0 +1,12 @@
+#ifndef TENSORFLOW_PLATFORM_LOGGING_H_
+#define TENSORFLOW_PLATFORM_LOGGING_H_
+
+#include "tensorflow/core/platform/port.h" // To pick up PLATFORM_define
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
+#include "base/logging.h"
+#else
+#include "tensorflow/core/platform/default/logging.h"
+#endif
+
+#endif // TENSORFLOW_PLATFORM_LOGGING_H_
diff --git a/tensorflow/core/platform/logging_test.cc b/tensorflow/core/platform/logging_test.cc
new file mode 100644
index 0000000000..03d734ae95
--- /dev/null
+++ b/tensorflow/core/platform/logging_test.cc
@@ -0,0 +1,76 @@
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(Logging, Log) {
+ LOG(INFO) << "Hello";
+ LOG(INFO) << "Another log message";
+ LOG(ERROR) << "Error message";
+ VLOG(1) << "A VLOG message";
+ VLOG(2) << "A higher VLOG message";
+}
+
+TEST(Logging, CheckChecks) {
+ CHECK(true);
+ CHECK(7 > 5);
+ string a("abc");
+ string b("xyz");
+ CHECK_EQ(a, a);
+ CHECK_NE(a, b);
+ CHECK_EQ(3, 3);
+ CHECK_NE(4, 3);
+ CHECK_GT(4, 3);
+ CHECK_GE(3, 3);
+ CHECK_LT(2, 3);
+ CHECK_LE(2, 3);
+
+ DCHECK(true);
+ DCHECK(7 > 5);
+ DCHECK_EQ(a, a);
+ DCHECK_NE(a, b);
+ DCHECK_EQ(3, 3);
+ DCHECK_NE(4, 3);
+ DCHECK_GT(4, 3);
+ DCHECK_GE(3, 3);
+ DCHECK_LT(2, 3);
+ DCHECK_LE(2, 3);
+}
+
+TEST(LoggingDeathTest, FailedChecks) {
+ string a("abc");
+ string b("xyz");
+ const char* p_const = "hello there";
+ const char* p_null_const = nullptr;
+ char mybuf[10];
+ char* p_non_const = mybuf;
+ char* p_null = nullptr;
+ CHECK_NOTNULL(p_const);
+ CHECK_NOTNULL(p_non_const);
+
+ ASSERT_DEATH(CHECK(false), "false");
+ ASSERT_DEATH(CHECK(9 < 7), "9 < 7");
+ ASSERT_DEATH(CHECK_EQ(a, b), "a == b");
+ ASSERT_DEATH(CHECK_EQ(3, 4), "3 == 4");
+ ASSERT_DEATH(CHECK_NE(3, 3), "3 != 3");
+ ASSERT_DEATH(CHECK_GT(2, 3), "2 > 3");
+ ASSERT_DEATH(CHECK_GE(2, 3), "2 >= 3");
+ ASSERT_DEATH(CHECK_LT(3, 2), "3 < 2");
+ ASSERT_DEATH(CHECK_LE(3, 2), "3 <= 2");
+ ASSERT_DEATH(CHECK(false), "false");
+ ASSERT_DEATH(printf("%s", CHECK_NOTNULL(p_null)), "Must be non NULL");
+ ASSERT_DEATH(printf("%s", CHECK_NOTNULL(p_null_const)), "Must be non NULL");
+#ifndef NDEBUG
+ ASSERT_DEATH(DCHECK(9 < 7), "9 < 7");
+ ASSERT_DEATH(DCHECK(9 < 7), "9 < 7");
+ ASSERT_DEATH(DCHECK_EQ(a, b), "a == b");
+ ASSERT_DEATH(DCHECK_EQ(3, 4), "3 == 4");
+ ASSERT_DEATH(DCHECK_NE(3, 3), "3 != 3");
+ ASSERT_DEATH(DCHECK_GT(2, 3), "2 > 3");
+ ASSERT_DEATH(DCHECK_GE(2, 3), "2 >= 3");
+ ASSERT_DEATH(DCHECK_LT(3, 2), "3 < 2");
+ ASSERT_DEATH(DCHECK_LE(3, 2), "3 <= 2");
+#endif
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/port.h b/tensorflow/core/platform/port.h
new file mode 100644
index 0000000000..fef20f7753
--- /dev/null
+++ b/tensorflow/core/platform/port.h
@@ -0,0 +1,228 @@
+#ifndef TENSORFLOW_PLATFORM_PORT_H_
+#define TENSORFLOW_PLATFORM_PORT_H_
+
+#include <string>
+#include <vector>
+
+#if !defined(PLATFORM_POSIX) && !defined(PLATFORM_GOOGLE) && \
+ !defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID)
+
+// Choose which platform we are on.
+#if defined(ANDROID) || defined(__ANDROID__)
+#define PLATFORM_POSIX_ANDROID
+#elif defined(__APPLE__)
+#define PLATFORM_POSIX
+#else
+// If no platform specified, use:
+#define PLATFORM_POSIX
+#endif
+
+#endif
+
+// Define tensorflow::string to refer to appropriate platform specific type.
+namespace tensorflow {
+#if defined(PLATFORM_GOOGLE)
+using ::string;
+#else
+using std::string;
+#endif
+} // namespace tensorflow
+
+namespace tensorflow {
+enum ConditionResult { kCond_Timeout, kCond_MaybeNotified };
+} // namespace tensorflow
+
+// Include appropriate platform-dependent implementations of mutex etc.
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/integral_types.h"
+#include "tensorflow/core/platform/google/mutex.h"
+#include "tensorflow/core/platform/google/dynamic_annotations.h"
+#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
+ defined(PLATFORM_GOOGLE_ANDROID)
+#include "tensorflow/core/platform/default/integral_types.h"
+#include "tensorflow/core/platform/default/mutex.h"
+#include "tensorflow/core/platform/default/dynamic_annotations.h"
+#else
+#error Define the appropriate PLATFORM_<foo> macro for this platform
+#endif
+
+namespace tensorflow {
+
+static const uint8 kuint8max = ((uint8)0xFF);
+static const uint16 kuint16max = ((uint16)0xFFFF);
+static const uint32 kuint32max = ((uint32)0xFFFFFFFF);
+static const uint64 kuint64max = ((uint64)0xFFFFFFFFFFFFFFFFull);
+static const int8 kint8min = ((int8)~0x7F);
+static const int8 kint8max = ((int8)0x7F);
+static const int16 kint16min = ((int16)~0x7FFF);
+static const int16 kint16max = ((int16)0x7FFF);
+static const int32 kint32min = ((int32)~0x7FFFFFFF);
+static const int32 kint32max = ((int32)0x7FFFFFFF);
+static const int64 kint64min = ((int64)~0x7FFFFFFFFFFFFFFFll);
+static const int64 kint64max = ((int64)0x7FFFFFFFFFFFFFFFll);
+
+// A typedef for a uint64 used as a short fingerprint.
+typedef uint64 Fprint;
+
+// The mutex library included above defines:
+// class mutex;
+// class mutex_lock;
+// class condition_variable;
+// It also defines the following:
+
+// Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds.
+//
+// Returns kCond_Timeout if the timeout expired without this thread
+// noticing a signal on the condition variable. Otherwise it may return
+// either kCond_Timeout or kCond_MaybeNotified (spurious wakeups are possible).
+ConditionResult WaitForMilliseconds(mutex_lock* mu, condition_variable* cv,
+ int64 ms);
+} // namespace tensorflow
+
+namespace tensorflow {
+namespace port {
+
+// TODO(jeff,sanjay): Make portable
+static const bool kLittleEndian = true;
+
+// TODO(jeff,sanjay): Find appropriate places for all the code below.
+// Possible places for any particular item below:
+// (a) Here, so it gets reimplemented on every platform
+// (b) Env
+// (c) config.h (auto-generated by autotools?)
+// (d) macros.h
+// ...
+
+// Return the hostname of the machine on which this process is running
+string Hostname();
+
+// Returns an estimate of the number of schedulable CPUs for this
+// process. Usually, it's constant throughout the lifetime of a
+// process, but it might change if the underlying cluster management
+// software can change it dynamically.
+int NumSchedulableCPUs();
+
+// Some platforms require that filenames be of a certain form when
+// used for logging. This function is invoked to allow platforms to
+// adjust the filename used for logging appropriately, if necessary
+// (most ports can just do nothing). If any changes are necessary, the
+// implementation should mutate "*filename" appropriately.
+void AdjustFilenameForLogging(string* filename);
+
+// Aligned allocation/deallocation
+void* aligned_malloc(size_t size, int minimum_alignment);
+void aligned_free(void* aligned_memory);
+
+// Prefetching support
+//
+// Defined behavior on some of the uarchs:
+// PREFETCH_HINT_T0:
+// prefetch to all levels of the hierarchy (except on p4: prefetch to L2)
+// PREFETCH_HINT_NTA:
+// p4: fetch to L2, but limit to 1 way (out of the 8 ways)
+// core: skip L2, go directly to L1
+// k8 rev E and later: skip L2, can go to either of the 2-ways in L1
+enum PrefetchHint {
+ PREFETCH_HINT_T0 = 3, // More temporal locality
+ PREFETCH_HINT_T1 = 2,
+ PREFETCH_HINT_T2 = 1, // Less temporal locality
+ PREFETCH_HINT_NTA = 0 // No temporal locality
+};
+template <PrefetchHint hint>
+void prefetch(const void* x);
+
+// Snappy compression/decompression support
+bool Snappy_Compress(const char* input, size_t length, string* output);
+
+bool Snappy_GetUncompressedLength(const char* input, size_t length,
+ size_t* result);
+bool Snappy_Uncompress(const char* input, size_t length, char* output);
+
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
+// Define this to 1 if the code is compiled in C++11 mode; leave it
+// undefined otherwise. Do NOT define it to 0 -- that causes
+// '#ifdef LANG_CXX11' to behave differently from '#if LANG_CXX11'.
+#define LANG_CXX11 1
+#endif
+
+// Compiler attributes
+#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG)
+// Compiler supports GCC-style attributes
+#define TF_ATTRIBUTE_NORETURN __attribute__((noreturn))
+#define TF_ATTRIBUTE_NOINLINE __attribute__((noinline))
+#define TF_ATTRIBUTE_UNUSED __attribute__((unused))
+#define TF_ATTRIBUTE_COLD __attribute__((cold))
+#define TF_PACKED __attribute__((packed))
+#define TF_MUST_USE_RESULT __attribute__((warn_unused_result))
+#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check) \
+ __attribute__((__format__(__printf__, string_index, first_to_check)))
+#define TF_SCANF_ATTRIBUTE(string_index, first_to_check) \
+ __attribute__((__format__(__scanf__, string_index, first_to_check)))
+
+#else
+// Non-GCC equivalents
+#define TF_ATTRIBUTE_NORETURN
+#define TF_ATTRIBUTE_NOINLINE
+#define TF_ATTRIBUTE_UNUSED
+#define TF_ATTRIBUTE_COLD
+#define TF_MUST_USE_RESULT
+#define TF_PACKED
+#define TF_PRINTF_ATTRIBUTE(string_index, first_to_check)
+#define TF_SCANF_ATTRIBUTE(string_index, first_to_check)
+#endif
+
+// GCC can be told that a certain branch is not likely to be taken (for
+// instance, a CHECK failure), and use that information in static analysis.
+// Giving it this information can help it optimize for the common case in
+// the absence of better information (i.e. profile data from -fprofile-arcs).
+//
+#if defined(COMPILER_GCC3)
+#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
+#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
+#else
+#define TF_PREDICT_FALSE(x) x
+#define TF_PREDICT_TRUE(x) x
+#endif
+
+// ---------------------------------------------------------------------------
+// Inline implementations of some performance-critical methods
+// ---------------------------------------------------------------------------
+template <PrefetchHint hint>
+inline void prefetch(const void* x) {
+#if defined(__llvm__) || defined(COMPILER_GCC)
+ __builtin_prefetch(x, 0, hint);
+#else
+// You get no effect. Feel free to add more sections above.
+#endif
+}
+
+// A macro to disallow the copy constructor and operator= functions
+// This is usually placed in the private: declarations for a class.
+#define TF_DISALLOW_COPY_AND_ASSIGN(TypeName) \
+ TypeName(const TypeName&) = delete; \
+ void operator=(const TypeName&) = delete
+
+// The TF_ARRAYSIZE(arr) macro returns the # of elements in an array arr.
+//
+// The expression TF_ARRAYSIZE(a) is a compile-time constant of type
+// size_t.
+#define TF_ARRAYSIZE(a) \
+ ((sizeof(a) / sizeof(*(a))) / \
+ static_cast<size_t>(!(sizeof(a) % sizeof(*(a)))))
+
+#if defined(__clang__) && defined(LANG_CXX11) && defined(__has_warning)
+#if __has_feature(cxx_attributes) && __has_warning("-Wimplicit-fallthrough")
+#define TF_FALLTHROUGH_INTENDED [[clang::fallthrough]] // NOLINT
+#endif
+#endif
+
+#ifndef TF_FALLTHROUGH_INTENDED
+#define TF_FALLTHROUGH_INTENDED \
+ do { \
+ } while (0)
+#endif
+
+} // namespace port
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_PORT_H_
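As an illustration only, the sketch below exercises two of the helpers declared above, the `prefetch` template and `TF_ARRAYSIZE`. The summation code is hypothetical and the prefetch distance of 16 elements is an arbitrary choice.

```c++
#include "tensorflow/core/platform/port.h"

namespace example {

// Sums an array, issuing a software prefetch a few elements ahead. On
// compilers without __builtin_prefetch the prefetch call compiles to nothing.
inline int Sum(const int* data, int n) {
  int total = 0;
  for (int i = 0; i < n; ++i) {
    if (i + 16 < n) {
      tensorflow::port::prefetch<tensorflow::port::PREFETCH_HINT_T0>(
          &data[i + 16]);
    }
    total += data[i];
  }
  return total;
}

inline int SumFixed() {
  static const int kValues[] = {1, 2, 3, 4};
  // TF_ARRAYSIZE yields the element count (4 here), not the byte size.
  return Sum(kValues, static_cast<int>(TF_ARRAYSIZE(kValues)));
}

}  // namespace example
```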
diff --git a/tensorflow/core/platform/port_test.cc b/tensorflow/core/platform/port_test.cc
new file mode 100644
index 0000000000..8cf1c30aa3
--- /dev/null
+++ b/tensorflow/core/platform/port_test.cc
@@ -0,0 +1,48 @@
+#include "tensorflow/core/platform/port.h"
+#include <condition_variable>
+#include "tensorflow/core/lib/core/threadpool.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace port {
+
+TEST(Port, AlignedMalloc) {
+ for (size_t alignment = 1; alignment <= 1 << 20; alignment <<= 1) {
+ void* p = aligned_malloc(1, alignment);
+ ASSERT_TRUE(p != NULL) << "aligned_malloc(1, " << alignment << ")";
+ uintptr_t pval = reinterpret_cast<uintptr_t>(p);
+ EXPECT_EQ(pval % alignment, 0);
+ aligned_free(p);
+ }
+}
+
+TEST(ConditionVariable, WaitForMilliseconds_Timeout) {
+ mutex m;
+ mutex_lock l(m);
+ condition_variable cv;
+ time_t start = time(NULL);
+ EXPECT_EQ(WaitForMilliseconds(&l, &cv, 3000), kCond_Timeout);
+ time_t finish = time(NULL);
+ EXPECT_GE(finish - start, 3);
+}
+
+TEST(ConditionVariable, WaitForMilliseconds_Signalled) {
+ thread::ThreadPool pool(Env::Default(), "test", 1);
+ mutex m;
+ mutex_lock l(m);
+ condition_variable cv;
+ time_t start = time(NULL);
+ // Sleep for just 1 second then notify. We have a timeout of 3 secs,
+ // so the condition variable will notice the cv signal before the timeout.
+ pool.Schedule([&m, &cv]() {
+ sleep(1);
+ mutex_lock l(m);
+ cv.notify_all();
+ });
+ EXPECT_EQ(WaitForMilliseconds(&l, &cv, 3000), kCond_MaybeNotified);
+ time_t finish = time(NULL);
+ EXPECT_LT(finish - start, 3);
+}
+
+} // namespace port
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc
new file mode 100644
index 0000000000..6ba2010005
--- /dev/null
+++ b/tensorflow/core/platform/posix/env.cc
@@ -0,0 +1,385 @@
+#include <dirent.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <time.h>
+#include <unistd.h>
+
+#include <thread>
+
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+error::Code ErrnoToCode(int err_number) {
+ error::Code code;
+ switch (err_number) {
+ case 0:
+ code = error::OK;
+ break;
+ case EINVAL: // Invalid argument
+ case ENAMETOOLONG: // Filename too long
+ case E2BIG: // Argument list too long
+ case EDESTADDRREQ: // Destination address required
+ case EDOM: // Mathematics argument out of domain of function
+ case EFAULT: // Bad address
+ case EILSEQ: // Illegal byte sequence
+ case ENOPROTOOPT: // Protocol not available
+ case ENOSTR: // Not a STREAM
+ case ENOTSOCK: // Not a socket
+ case ENOTTY: // Inappropriate I/O control operation
+ case EPROTOTYPE: // Protocol wrong type for socket
+ case ESPIPE: // Invalid seek
+ code = error::INVALID_ARGUMENT;
+ break;
+ case ETIMEDOUT: // Connection timed out
+ case ETIME: // Timer expired
+ code = error::DEADLINE_EXCEEDED;
+ break;
+ case ENODEV: // No such device
+ case ENOENT: // No such file or directory
+ case ENXIO: // No such device or address
+ case ESRCH: // No such process
+ code = error::NOT_FOUND;
+ break;
+ case EEXIST: // File exists
+ case EADDRNOTAVAIL: // Address not available
+ case EALREADY: // Connection already in progress
+ code = error::ALREADY_EXISTS;
+ break;
+ case EPERM: // Operation not permitted
+ case EACCES: // Permission denied
+ case EROFS: // Read only file system
+ code = error::PERMISSION_DENIED;
+ break;
+ case ENOTEMPTY: // Directory not empty
+ case EISDIR: // Is a directory
+ case ENOTDIR: // Not a directory
+ case EADDRINUSE: // Address already in use
+ case EBADF: // Invalid file descriptor
+ case EBUSY: // Device or resource busy
+ case ECHILD: // No child processes
+ case EISCONN: // Socket is connected
+ case ENOTBLK: // Block device required
+ case ENOTCONN: // The socket is not connected
+ case EPIPE: // Broken pipe
+ case ESHUTDOWN: // Cannot send after transport endpoint shutdown
+ case ETXTBSY: // Text file busy
+ code = error::FAILED_PRECONDITION;
+ break;
+ case ENOSPC: // No space left on device
+ case EDQUOT: // Disk quota exceeded
+ case EMFILE: // Too many open files
+ case EMLINK: // Too many links
+ case ENFILE: // Too many open files in system
+ case ENOBUFS: // No buffer space available
+ case ENODATA: // No message is available on the STREAM read queue
+ case ENOMEM: // Not enough space
+ case ENOSR: // No STREAM resources
+ case EUSERS: // Too many users
+ code = error::RESOURCE_EXHAUSTED;
+ break;
+ case EFBIG: // File too large
+ case EOVERFLOW: // Value too large to be stored in data type
+ case ERANGE: // Result too large
+ code = error::OUT_OF_RANGE;
+ break;
+ case ENOSYS: // Function not implemented
+ case ENOTSUP: // Operation not supported
+ case EAFNOSUPPORT: // Address family not supported
+ case EPFNOSUPPORT: // Protocol family not supported
+ case EPROTONOSUPPORT: // Protocol not supported
+ case ESOCKTNOSUPPORT: // Socket type not supported
+ case EXDEV: // Improper link
+ code = error::UNIMPLEMENTED;
+ break;
+ case EAGAIN: // Resource temporarily unavailable
+ case ECONNREFUSED: // Connection refused
+ case ECONNABORTED: // Connection aborted
+ case ECONNRESET: // Connection reset
+ case EINTR: // Interrupted function call
+ case EHOSTDOWN: // Host is down
+ case EHOSTUNREACH: // Host is unreachable
+ case ENETDOWN: // Network is down
+ case ENETRESET: // Connection aborted by network
+ case ENETUNREACH: // Network unreachable
+ case ENOLCK: // No locks available
+ case ENOLINK: // Link has been severed
+#if !defined(__APPLE__)
+ case ENONET: // Machine is not on the network
+#endif
+ code = error::UNAVAILABLE;
+ break;
+ case EDEADLK: // Resource deadlock avoided
+ case ESTALE: // Stale file handle
+ code = error::ABORTED;
+ break;
+ case ECANCELED: // Operation cancelled
+ code = error::CANCELLED;
+ break;
+ // NOTE: If you get any of the following (especially in a
+ // reproducible way) and can propose a better mapping,
+ // please email the owners about updating this mapping.
+ case EBADMSG: // Bad message
+ case EIDRM: // Identifier removed
+ case EINPROGRESS: // Operation in progress
+ case EIO: // I/O error
+ case ELOOP: // Too many levels of symbolic links
+ case ENOEXEC: // Exec format error
+ case ENOMSG: // No message of the desired type
+ case EPROTO: // Protocol error
+ case EREMOTE: // Object is remote
+ code = error::UNKNOWN;
+ break;
+ default: {
+ code = error::UNKNOWN;
+ break;
+ }
+ }
+ return code;
+}
+
+static Status IOError(const string& context, int err_number) {
+ auto code = ErrnoToCode(err_number);
+ if (code == error::UNKNOWN) {
+ return Status(code, context + "; " + strerror(err_number));
+ } else {
+ return Status(code, context);
+ }
+}
+
+// pread() based random-access
+class PosixRandomAccessFile : public RandomAccessFile {
+ private:
+ string filename_;
+ int fd_;
+
+ public:
+ PosixRandomAccessFile(const string& fname, int fd)
+ : filename_(fname), fd_(fd) {}
+ ~PosixRandomAccessFile() override { close(fd_); }
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override {
+ Status s;
+ char* dst = scratch;
+ while (n > 0 && s.ok()) {
+ ssize_t r = pread(fd_, dst, n, static_cast<off_t>(offset));
+ if (r > 0) {
+ dst += r;
+ n -= r;
+ offset += r;
+ } else if (r == 0) {
+ s = Status(error::OUT_OF_RANGE, "Read fewer bytes than requested");
+ } else if (errno == EINTR || errno == EAGAIN) {
+ // Retry
+ } else {
+ s = IOError(filename_, errno);
+ }
+ }
+ *result = StringPiece(scratch, dst - scratch);
+ return s;
+ }
+};
+
+class PosixWritableFile : public WritableFile {
+ private:
+ string filename_;
+ FILE* file_;
+
+ public:
+ PosixWritableFile(const string& fname, FILE* f)
+ : filename_(fname), file_(f) {}
+
+ ~PosixWritableFile() override {
+ if (file_ != NULL) {
+ // Ignoring any potential errors
+ fclose(file_);
+ }
+ }
+
+ Status Append(const StringPiece& data) override {
+ size_t r = fwrite(data.data(), 1, data.size(), file_);
+ if (r != data.size()) {
+ return IOError(filename_, errno);
+ }
+ return Status::OK();
+ }
+
+ Status Close() override {
+ Status result;
+ if (fclose(file_) != 0) {
+ result = IOError(filename_, errno);
+ }
+ file_ = NULL;
+ return result;
+ }
+
+ Status Flush() override {
+ if (fflush(file_) != 0) {
+ return IOError(filename_, errno);
+ }
+ return Status::OK();
+ }
+
+ Status Sync() override {
+ Status s;
+ if (fflush(file_) != 0) {
+ s = IOError(filename_, errno);
+ }
+ return s;
+ }
+};
+
+class StdThread : public Thread {
+ public:
+ // name and thread_options are both ignored.
+ StdThread(const ThreadOptions& thread_options, const string& name,
+ std::function<void()> fn)
+ : thread_(fn) {}
+ ~StdThread() { thread_.join(); }
+
+ private:
+ std::thread thread_;
+};
+
+class PosixEnv : public Env {
+ public:
+ PosixEnv() {}
+
+ ~PosixEnv() override { LOG(FATAL) << "Env::Default() must not be destroyed"; }
+
+ Status NewRandomAccessFile(const string& fname,
+ RandomAccessFile** result) override {
+ *result = NULL;
+ Status s;
+ int fd = open(fname.c_str(), O_RDONLY);
+ if (fd < 0) {
+ s = IOError(fname, errno);
+ } else {
+ *result = new PosixRandomAccessFile(fname, fd);
+ }
+ return s;
+ }
+
+ Status NewWritableFile(const string& fname, WritableFile** result) override {
+ Status s;
+ FILE* f = fopen(fname.c_str(), "w");
+ if (f == NULL) {
+ *result = NULL;
+ s = IOError(fname, errno);
+ } else {
+ *result = new PosixWritableFile(fname, f);
+ }
+ return s;
+ }
+
+ Status NewAppendableFile(const string& fname,
+ WritableFile** result) override {
+ Status s;
+ FILE* f = fopen(fname.c_str(), "a");
+ if (f == NULL) {
+ *result = NULL;
+ s = IOError(fname, errno);
+ } else {
+ *result = new PosixWritableFile(fname, f);
+ }
+ return s;
+ }
+
+ bool FileExists(const string& fname) override {
+ return access(fname.c_str(), F_OK) == 0;
+ }
+
+ Status GetChildren(const string& dir, std::vector<string>* result) override {
+ result->clear();
+ DIR* d = opendir(dir.c_str());
+ if (d == NULL) {
+ return IOError(dir, errno);
+ }
+ struct dirent* entry;
+ while ((entry = readdir(d)) != NULL) {
+ StringPiece basename = entry->d_name;
+ if ((basename != ".") && (basename != "..")) {
+ result->push_back(entry->d_name);
+ }
+ }
+ closedir(d);
+ return Status::OK();
+ }
+
+ Status DeleteFile(const string& fname) override {
+ Status result;
+ if (unlink(fname.c_str()) != 0) {
+ result = IOError(fname, errno);
+ }
+ return result;
+ }
+
+ Status CreateDir(const string& name) override {
+ Status result;
+ if (mkdir(name.c_str(), 0755) != 0) {
+ result = IOError(name, errno);
+ }
+ return result;
+ }
+
+ Status DeleteDir(const string& name) override {
+ Status result;
+ if (rmdir(name.c_str()) != 0) {
+ result = IOError(name, errno);
+ }
+ return result;
+ }
+
+ Status GetFileSize(const string& fname, uint64* size) override {
+ Status s;
+ struct stat sbuf;
+ if (stat(fname.c_str(), &sbuf) != 0) {
+ *size = 0;
+ s = IOError(fname, errno);
+ } else {
+ *size = sbuf.st_size;
+ }
+ return s;
+ }
+
+ Status RenameFile(const string& src, const string& target) override {
+ Status result;
+ if (rename(src.c_str(), target.c_str()) != 0) {
+ result = IOError(src, errno);
+ }
+ return result;
+ }
+
+ uint64 NowMicros() override {
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+ return static_cast<uint64>(tv.tv_sec) * 1000000 + tv.tv_usec;
+ }
+
+ void SleepForMicroseconds(int micros) override { usleep(micros); }
+
+ Thread* StartThread(const ThreadOptions& thread_options, const string& name,
+ std::function<void()> fn) override {
+ return new StdThread(thread_options, name, fn);
+ }
+};
+
+} // namespace
+#if defined(PLATFORM_POSIX) || defined(__ANDROID__)
+Env* Env::Default() {
+ static Env* default_env = new PosixEnv;
+ return default_env;
+}
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
new file mode 100644
index 0000000000..b4a1570ef9
--- /dev/null
+++ b/tensorflow/core/platform/posix/port.cc
@@ -0,0 +1,92 @@
+#include "tensorflow/core/platform/port.h"
+#if defined(__linux) && !defined(__ANDROID__)
+#include <sched.h>
+#endif
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#ifdef SNAPPY
+#include <snappy.h>
+#endif
+
+namespace tensorflow {
+namespace port {
+
+void InitMain(const char* usage, int* argc, char*** argv) {}
+
+string Hostname() {
+ char hostname[1024];
+ gethostname(hostname, sizeof hostname);
+ hostname[sizeof hostname - 1] = 0;
+ return string(hostname);
+}
+
+int NumSchedulableCPUs() {
+#if defined(__linux) && !defined(__ANDROID__)
+ cpu_set_t cpuset;
+ if (sched_getaffinity(0, sizeof(cpu_set_t), &cpuset) == 0) {
+ return CPU_COUNT(&cpuset);
+ }
+ perror("sched_getaffinity");
+#endif
+ const int kDefaultCores = 4; // Semi-conservative guess
+ fprintf(stderr, "can't determine number of CPU cores: assuming %d\n",
+ kDefaultCores);
+ return kDefaultCores;
+}
+
+void* aligned_malloc(size_t size, int minimum_alignment) {
+#if defined(__ANDROID__)
+ return memalign(minimum_alignment, size);
+#else // !__ANDROID__
+ void* ptr = NULL;
+ // posix_memalign requires that the requested alignment be at least
+ // sizeof(void*). In this case, fall back on malloc which should return
+ // memory aligned to at least the size of a pointer.
+ const int required_alignment = sizeof(void*);
+ if (minimum_alignment < required_alignment) return malloc(size);
+ if (posix_memalign(&ptr, minimum_alignment, size) != 0)
+ return NULL;
+ else
+ return ptr;
+#endif
+}
+
+void aligned_free(void* aligned_memory) { free(aligned_memory); }
+
+void AdjustFilenameForLogging(string* filename) {
+ // Nothing to do
+}
+
+bool Snappy_Compress(const char* input, size_t length, string* output) {
+#ifdef SNAPPY
+ output->resize(snappy::MaxCompressedLength(length));
+ size_t outlen;
+ snappy::RawCompress(input, length, &(*output)[0], &outlen);
+ output->resize(outlen);
+ return true;
+#else
+ return false;
+#endif
+}
+
+bool Snappy_GetUncompressedLength(const char* input, size_t length,
+ size_t* result) {
+#ifdef SNAPPY
+ return snappy::GetUncompressedLength(input, length, result);
+#else
+ return false;
+#endif
+}
+
+bool Snappy_Uncompress(const char* input, size_t length, char* output) {
+#ifdef SNAPPY
+ return snappy::RawUncompress(input, length, output);
+#else
+ return false;
+#endif
+}
+
+} // namespace port
+} // namespace tensorflow
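A hypothetical round trip through the Snappy wrappers above (not part of this diff). Note that all three functions simply return false when the build does not define SNAPPY, so callers need a fallback path.

```c++
#include "tensorflow/core/platform/port.h"

bool SnappyRoundTrip(const tensorflow::string& input,
                     tensorflow::string* restored) {
  namespace port = tensorflow::port;

  tensorflow::string compressed;
  if (!port::Snappy_Compress(input.data(), input.size(), &compressed)) {
    return false;  // Snappy not compiled in.
  }

  // The caller allocates the output buffer, so first ask for the
  // uncompressed length, then decompress into it.
  size_t uncompressed_len = 0;
  if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
                                          &uncompressed_len)) {
    return false;
  }
  restored->resize(uncompressed_len);
  return port::Snappy_Uncompress(compressed.data(), compressed.size(),
                                 &(*restored)[0]);
}
```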
diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h
new file mode 100644
index 0000000000..3a166b3973
--- /dev/null
+++ b/tensorflow/core/platform/protobuf.h
@@ -0,0 +1,29 @@
+#ifndef TENSORFLOW_PLATFORM_PROTOBUF_H_
+#define TENSORFLOW_PLATFORM_PROTOBUF_H_
+
+// Import whatever namespace protobuf comes from into the
+// ::tensorflow::protobuf namespace.
+//
+// TensorFlow code should use the ::tensorflow::protobuf namespace to refer
+// to all protobuf APIs.
+
+#include "tensorflow/core/platform/port.h"
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/protobuf.h"
+#elif defined(PLATFORM_GOOGLE_ANDROID)
+#include "tensorflow/core/platform/google/protobuf_android.h"
+#else
+#include "tensorflow/core/platform/default/protobuf.h"
+#endif
+
+namespace tensorflow {
+// Parses a protocol buffer contained in a string in the binary wire format.
+// Returns true on success. Note: Unlike protobuf's builtin ParseFromString,
+// this function has no size restrictions on the total size of the encoded
+// protocol buffer.
+bool ParseProtoUnlimited(protobuf::Message* proto, const string& serialized);
+bool ParseProtoUnlimited(protobuf::Message* proto, const void* serialized,
+ size_t size);
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_PROTOBUF_H_
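As a usage sketch (not part of this diff), the helper below parses a serialized message without protobuf's default size cap; `GraphDef` is used purely as an example message type and the function name is made up.

```c++
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/protobuf.h"

bool ParseLargeGraph(const tensorflow::string& serialized,
                     tensorflow::GraphDef* graph) {
  // Unlike GraphDef::ParseFromString, this does not enforce protobuf's
  // default 64MB limit on the encoded size.
  return tensorflow::ParseProtoUnlimited(graph, serialized);
}
```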
diff --git a/tensorflow/core/platform/protobuf_util.cc b/tensorflow/core/platform/protobuf_util.cc
new file mode 100644
index 0000000000..b698d3f0c2
--- /dev/null
+++ b/tensorflow/core/platform/protobuf_util.cc
@@ -0,0 +1,17 @@
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+bool ParseProtoUnlimited(protobuf::Message* proto, const string& serialized) {
+ return ParseProtoUnlimited(proto, serialized.data(), serialized.size());
+}
+
+bool ParseProtoUnlimited(protobuf::Message* proto, const void* serialized,
+ size_t size) {
+ protobuf::io::CodedInputStream coded_stream(
+ reinterpret_cast<const uint8*>(serialized), size);
+ coded_stream.SetTotalBytesLimit(INT_MAX, INT_MAX);
+ return proto->ParseFromCodedStream(&coded_stream);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/regexp.h b/tensorflow/core/platform/regexp.h
new file mode 100644
index 0000000000..ef46a7aca5
--- /dev/null
+++ b/tensorflow/core/platform/regexp.h
@@ -0,0 +1,33 @@
+#ifndef TENSORFLOW_PLATFORM_REGEXP_H_
+#define TENSORFLOW_PLATFORM_REGEXP_H_
+
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
+#include "third_party/re2/re2.h"
+namespace tensorflow {
+typedef ::StringPiece RegexpStringPiece;
+} // namespace tensorflow
+
+#else
+
+#include "external/re2/re2/re2.h"
+namespace tensorflow {
+typedef re2::StringPiece RegexpStringPiece;
+} // namespace tensorflow
+
+#endif
+
+namespace tensorflow {
+
+// Conversion to/from the appropriate StringPiece type for using in RE2
+inline RegexpStringPiece ToRegexpStringPiece(tensorflow::StringPiece sp) {
+ return RegexpStringPiece(sp.data(), sp.size());
+}
+inline tensorflow::StringPiece FromRegexpStringPiece(RegexpStringPiece sp) {
+ return tensorflow::StringPiece(sp.data(), sp.size());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_REGEXP_H_
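An illustrative sketch (not part of this diff) of the conversion helpers used with RE2; the pattern and function name are hypothetical.

```c++
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/regexp.h"

// Extracts the numeric suffix of a device name such as "/gpu:3".
bool ExtractGpuId(tensorflow::StringPiece name, tensorflow::string* id) {
  // RE2 only understands its own StringPiece type, so convert on the way in.
  return RE2::FullMatch(tensorflow::ToRegexpStringPiece(name), "/gpu:(\\d+)",
                        id);
}
```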
diff --git a/tensorflow/core/platform/stream_executor_util.h b/tensorflow/core/platform/stream_executor_util.h
new file mode 100644
index 0000000000..a6640fb26d
--- /dev/null
+++ b/tensorflow/core/platform/stream_executor_util.h
@@ -0,0 +1,12 @@
+#ifndef TENSORFLOW_PLATFORM_STREAM_EXECUTOR_UTIL_H_
+#define TENSORFLOW_PLATFORM_STREAM_EXECUTOR_UTIL_H_
+
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/stream_executor_util.h"
+#else
+#include "tensorflow/core/platform/default/stream_executor_util.h"
+#endif
+
+#endif // TENSORFLOW_PLATFORM_STREAM_EXECUTOR_UTIL_H_
diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc
new file mode 100644
index 0000000000..a5cbd0ab44
--- /dev/null
+++ b/tensorflow/core/platform/tensor_coding.cc
@@ -0,0 +1,53 @@
+#include "tensorflow/core/platform/tensor_coding.h"
+
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace port {
+
+void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) {
+ out->assign(src.data(), src.size());
+}
+
+void EncodeStringList(const string* strings, int64 n, string* out) {
+ out->clear();
+ for (int i = 0; i < n; ++i) {
+ core::PutVarint32(out, strings[i].size());
+ }
+ for (int i = 0; i < n; ++i) {
+ out->append(strings[i]);
+ }
+}
+
+bool DecodeStringList(const string& src, string* strings, int64 n) {
+ std::vector<uint32> sizes(n);
+ StringPiece reader(src);
+ int64 tot = 0;
+ for (auto& v : sizes) {
+ if (!core::GetVarint32(&reader, &v)) return false;
+ tot += v;
+ }
+ if (tot != static_cast<int64>(reader.size())) {
+ return false;
+ }
+
+ string* data = strings;
+ for (int64 i = 0; i < n; ++i, ++data) {
+ auto size = sizes[i];
+ if (size > reader.size()) {
+ return false;
+ }
+ data->assign(reader.data(), size);
+ reader.remove_prefix(size);
+ }
+
+ return true;
+}
+
+void CopyFromArray(string* s, const char* base, size_t bytes) {
+ s->assign(base, bytes);
+}
+
+} // namespace port
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/tensor_coding.h b/tensorflow/core/platform/tensor_coding.h
new file mode 100644
index 0000000000..6bb9991895
--- /dev/null
+++ b/tensorflow/core/platform/tensor_coding.h
@@ -0,0 +1,40 @@
+// Helper routines for encoding/decoding tensor contents.
+#ifndef TENSORFLOW_PLATFORM_TENSOR_CODING_H_
+#define TENSORFLOW_PLATFORM_TENSOR_CODING_H_
+
+#include <string>
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/port.h"
+
+#ifdef PLATFORM_GOOGLE
+#include "tensorflow/core/platform/google/cord_coding.h"
+#endif
+
+namespace tensorflow {
+namespace port {
+
+// Store src contents in *out. If backing memory for src is shared with *out,
+// will ref obj during the call and will arrange to unref obj when no
+// longer needed.
+void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out);
+
+// Copy contents of src to dst[0,src.size()-1].
+inline void CopyToArray(const string& src, char* dst) {
+ memcpy(dst, src.data(), src.size());
+}
+
+// Store encoding of strings[0..n-1] in *out.
+void EncodeStringList(const string* strings, int64 n, string* out);
+
+// Decode n strings from src and store in strings[0..n-1].
+// Returns true if successful, false on parse error.
+bool DecodeStringList(const string& src, string* strings, int64 n);
+
+// Assigns base[0..bytes-1] to *s
+void CopyFromArray(string* s, const char* base, size_t bytes);
+
+} // namespace port
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_TENSOR_CODING_H_
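A hypothetical round trip through the string-list coder declared above (not part of this diff). The encoding is all varint32 lengths first, followed by the concatenated string data, which is why DecodeStringList can cheaply validate total sizes before copying.

```c++
#include "tensorflow/core/platform/tensor_coding.h"

bool RoundTripStrings() {
  const tensorflow::string src[3] = {"a", "", "hello"};

  tensorflow::string encoded;
  tensorflow::port::EncodeStringList(src, 3, &encoded);

  tensorflow::string decoded[3];
  if (!tensorflow::port::DecodeStringList(encoded, decoded, 3)) {
    return false;  // Corrupt or truncated encoding.
  }
  return decoded[0] == src[0] && decoded[1] == src[1] && decoded[2] == src[2];
}
```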
diff --git a/tensorflow/core/platform/test.cc b/tensorflow/core/platform/test.cc
new file mode 100644
index 0000000000..21c6905683
--- /dev/null
+++ b/tensorflow/core/platform/test.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \
+ defined(PLATFORM_GOOGLE_ANDROID)
+#include "testing/base/public/googletest.h"
+#endif
+
+namespace tensorflow {
+namespace testing {
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \
+ defined(PLATFORM_GOOGLE_ANDROID)
+string TmpDir() { return FLAGS_test_tmpdir; }
+int RandomSeed() { return FLAGS_test_random_seed; }
+#else
+string TmpDir() {
+ // 'bazel test' sets TEST_TMPDIR
+ const char* env = getenv("TEST_TMPDIR");
+ if (env && env[0] != '\0') {
+ return env;
+ }
+ env = getenv("TMPDIR");
+ if (env && env[0] != '\0') {
+ return env;
+ }
+ return "/tmp";
+}
+int RandomSeed() {
+ const char* env = getenv("TEST_RANDOM_SEED");
+ int result;
+ if (env && sscanf(env, "%d", &result) == 1) {
+ return result;
+ }
+ return 301;
+}
+#endif
+
+} // namespace testing
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h
new file mode 100644
index 0000000000..ea16fe1442
--- /dev/null
+++ b/tensorflow/core/platform/test.h
@@ -0,0 +1,17 @@
+#ifndef TENSORFLOW_PLATFORM_TEST_H_
+#define TENSORFLOW_PLATFORM_TEST_H_
+
+namespace tensorflow {
+namespace testing {
+
+// Return a temporary directory suitable for temporary testing files.
+string TmpDir();
+
+// Return a random number generator seed to use in randomized tests.
+// Returns the same value for the lifetime of the process.
+int RandomSeed();
+
+} // namespace testing
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_TEST_H_
diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h
new file mode 100644
index 0000000000..8c8a92a519
--- /dev/null
+++ b/tensorflow/core/platform/test_benchmark.h
@@ -0,0 +1,58 @@
+// Simple benchmarking facility.
+#ifndef TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+#define TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE)
+#include "testing/base/public/benchmark.h"
+
+#else
+#define BENCHMARK(n) \
+ static ::tensorflow::testing::Benchmark* TF_BENCHMARK_CONCAT( \
+ __benchmark_, n, __LINE__) TF_ATTRIBUTE_UNUSED = \
+ (new ::tensorflow::testing::Benchmark(#n, (n)))
+#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
+#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
+
+#endif // PLATFORM_GOOGLE
+
+namespace tensorflow {
+namespace testing {
+
+#if defined(PLATFORM_GOOGLE)
+using ::testing::Benchmark;
+#else
+class Benchmark {
+ public:
+ Benchmark(const char* name, void (*fn)(int));
+ Benchmark(const char* name, void (*fn)(int, int));
+
+ Benchmark* Arg(int x);
+ Benchmark* Range(int lo, int hi);
+ static void Run(const char* pattern);
+
+ private:
+ string name_;
+ int num_args_;
+ std::vector<int> args_;
+ void (*fn0_)(int) = nullptr;
+ void (*fn1_)(int, int) = nullptr;
+
+ void Register();
+ void Run(int arg, int* run_count, double* run_seconds);
+};
+#endif
+
+void RunBenchmarks();
+void SetLabel(const std::string& label);
+void BytesProcessed(int64);
+void ItemsProcessed(int64);
+void StartTiming();
+void StopTiming();
+void UseRealTime();
+
+} // namespace testing
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
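A hypothetical benchmark using the open-source shim declared above; the work being measured is arbitrary. When linked with the test_main.cc below, it runs whenever the binary is passed a matching `--benchmarks=` pattern.

```c++
#include "tensorflow/core/platform/test_benchmark.h"

// Measures appending 64 characters to a fresh string, 'iters' times.
static void BM_StringAppend(int iters) {
  for (int i = 0; i < iters; ++i) {
    tensorflow::string s;
    for (int j = 0; j < 64; ++j) s.append("x");
  }
  tensorflow::testing::ItemsProcessed(static_cast<tensorflow::int64>(iters));
}
BENCHMARK(BM_StringAppend);
```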
diff --git a/tensorflow/core/platform/test_main.cc b/tensorflow/core/platform/test_main.cc
new file mode 100644
index 0000000000..11230c3f7b
--- /dev/null
+++ b/tensorflow/core/platform/test_main.cc
@@ -0,0 +1,31 @@
+// A program with a main that is suitable for unittests, including those
+// that also define microbenchmarks. Based on whether the user specified
+// the --benchmarks flag, which selects the benchmarks to run, we will
+// either run those benchmarks or run the gtest tests in the program.
+
+#include <iostream>
+
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX_ANDROID) || \
+ defined(PLATFORM_GOOGLE_ANDROID)
+// main() is supplied by gunit_main
+#else
+#include "gtest/gtest.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+GTEST_API_ int main(int argc, char** argv) {
+ std::cout << "Running main() from test_main.cc\n";
+
+ testing::InitGoogleTest(&argc, argv);
+ for (int i = 1; i < argc; i++) {
+ if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) {
+ const char* pattern = argv[i] + strlen("--benchmarks=");
+ tensorflow::testing::Benchmark::Run(pattern);
+ return 0;
+ }
+ }
+ return RUN_ALL_TESTS();
+}
+#endif
diff --git a/tensorflow/core/platform/thread_annotations.h b/tensorflow/core/platform/thread_annotations.h
new file mode 100644
index 0000000000..cb8040eed6
--- /dev/null
+++ b/tensorflow/core/platform/thread_annotations.h
@@ -0,0 +1,14 @@
+#ifndef TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
+
+#include "tensorflow/core/platform/port.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
+#include "base/thread_annotations.h"
+#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID)
+#include "tensorflow/core/platform/default/thread_annotations.h"
+#else
+#error Define the appropriate PLATFORM_<foo> macro for this platform
+#endif
+
+#endif // TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/tracing.cc b/tensorflow/core/platform/tracing.cc
new file mode 100644
index 0000000000..a4cb92dee4
--- /dev/null
+++ b/tensorflow/core/platform/tracing.cc
@@ -0,0 +1,135 @@
+#include "tensorflow/core/platform/tracing.h"
+
+#include <atomic>
+#include <map>
+#include <string>
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+StepStatsCollector::StepStatsCollector(StepStats* ss) : step_stats_(ss) {}
+
+void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
+ VLOG(1) << "Save dev " << device << " nt " << nt;
+ {
+ mutex_lock l(mu_);
+ DeviceStepStats* dss = nullptr;
+ // Slow linear scan, but it should only be called
+ // by a Worker in a context with < ~10 devices.
+ // TODO(tucker): consider adding a std::unordered_map.
+ for (auto& ds : *step_stats_->mutable_dev_stats()) {
+ if (ds.device() == device) {
+ dss = &ds;
+ break;
+ }
+ }
+ if (dss == nullptr) {
+ dss = step_stats_->add_dev_stats();
+ dss->set_device(device);
+ }
+ nt->Swap(dss->add_node_stats());
+ }
+ delete nt;
+}
+
+void StepStatsCollector::Swap(StepStats* ss) {
+ mutex_lock l(mu_);
+ CHECK(step_stats_);
+ ss->Swap(step_stats_);
+}
+
+namespace port {
+
+int32 Tracing::category_id_[kEventCategoryMax];
+uint64 Tracing::event_mask_ = 0;
+std::map<string, int32>* Tracing::name_map_ = new std::map<string, int32>;
+
+// This needs to be kept in sync with the EventCategory enumeration.
+const char* Tracing::EventCategoryString(EventCategory category) {
+ switch (category) {
+ case EventCategory::kScheduleClosure:
+ return "ScheduleClosure";
+ case EventCategory::kRunClosure:
+ return "RunClosure";
+ case EventCategory::kCompute:
+ return "Compute";
+ case EventCategory::kEventCategoryMax:
+ return "EventCategoryMax";
+ }
+ return "Unknown";
+}
+
+// This function allows the user to specify arbitrary subsets of the
+// supported Threadscape events and activities.
+bool Tracing::ParseEventMask(const char* flagname, const string& value) {
+ VLOG(1) << flagname << " set to " << value;
+ int64 new_mask = 0;
+ std::vector<string> events =
+ str_util::Split(value, ',', str_util::SkipEmpty());
+ for (string name : events) {
+ bool clear = false;
+ int64 mask = 0;
+ if (name[0] == '!') {
+ // invert the sense of the flag
+ clear = true;
+ name = name.substr(1);
+ }
+ if (name == "ALL") {
+ mask = ~0;
+ } else {
+ auto it = name_map_->find(name);
+ int32 id;
+ if (it == name_map_->end()) {
+ id = -1;
+ } else {
+ id = it->second;
+ }
+ if (id < 0) {
+ LOG(ERROR) << "Can't parse event mask name " << name;
+ return false;
+ }
+ mask = static_cast<int64>(1) << id;
+ }
+ if (clear) {
+ new_mask &= ~mask;
+ } else {
+ new_mask |= mask;
+ }
+ }
+ // parsing was successful; set the permanent event mask
+ event_mask_ = new_mask;
+ return true;
+}
+
+static std::atomic<Tracing::Engine*> tracing_engine;
+
+void Tracing::RegisterEngine(Engine* e) {
+ tracing_engine.store(e, std::memory_order_release);
+}
+
+static Tracing::Engine* engine() {
+ return tracing_engine.load(std::memory_order_acquire);
+}
+
+Tracing::Engine::~Engine() {}
+Tracing::Engine::Annotation::~Annotation() {}
+Tracing::Engine::Tracer::~Tracer() {}
+
+Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) {
+ auto e = engine();
+ if (e) {
+ annotation_.reset(e->PushAnnotation(name));
+ }
+}
+
+Tracing::TraceMe::TraceMe(StringPiece name) {
+ auto e = engine();
+ if (e) {
+ tracer_.reset(e->StartTracing(name));
+ }
+}
+
+} // namespace port
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
new file mode 100644
index 0000000000..2b53a64cf1
--- /dev/null
+++ b/tensorflow/core/platform/tracing.h
@@ -0,0 +1,205 @@
+#ifndef TENSORFLOW_PLATFORM_TRACING_H_
+#define TENSORFLOW_PLATFORM_TRACING_H_
+
+// Tracing interface
+
+#include <map>
+#include <memory>
+
+#include "tensorflow/core/platform/port.h" // Must be first
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+class NodeExecStats;
+class StepStats;
+
+class StepStatsCollector {
+ public:
+ explicit StepStatsCollector(StepStats* ss);
+
+ void Save(const string& device, NodeExecStats* nt);
+
+ void Swap(StepStats* ss);
+
+ private:
+ friend class StepStatsMgr;
+ mutex mu_;
+ StepStats* step_stats_ GUARDED_BY(mu_);
+};
+
+namespace port {
+
+class Tracing {
+ public:
+ // This enumeration contains the identifiers of all TensorFlow
+ // threadscape events and code regions. Threadscape assigns its
+ // own identifiers at runtime when we register our events and we
+ // cannot know in advance what IDs it will choose. The "RecordEvent"
+ // method and "ScopedActivity" use these event IDs for consistency
+ // and remap them to threadscape IDs at runtime. This enum is limited
+ // to 64 values since we use a bitmask to configure which events are
+ // enabled. It must also be kept in step with the code in
+ // "Tracing::EventCategoryString".
+ enum EventCategory {
+ kScheduleClosure = 0,
+ kRunClosure = 1,
+ kCompute = 2,
+ kEventCategoryMax = 3 // sentinel - keep last
+ };
+ // Note: We currently only support up to 64 categories.
+ static_assert(kEventCategoryMax <= 64, "only support up to 64 events");
+
+ // Called by main programs to initialize tracing facilities
+ static void Initialize();
+
+ // Return the pathname of the directory where we are writing log files.
+ static const char* LogDir();
+
+ // Returns a non-zero identifier which can be used to correlate
+ // related events.
+ static inline uint64 UniqueId();
+
+ // Returns true if a trace is in progress. Can be used to reduce tracing
+ // overheads in fast-path code.
+ static inline bool IsActive();
+
+ // Associate name with the current thread.
+ static void RegisterCurrentThread(const char* name);
+
+ // Posts an event with the supplied category and arg.
+ static void RecordEvent(EventCategory category, uint64 arg);
+
+ // Traces a region of code. Posts a tracing "EnterCodeRegion" event
+ // when created and an "ExitCodeRegion" event when destroyed.
+ class ScopedActivity {
+ public:
+ explicit ScopedActivity(EventCategory category, uint64 arg);
+ ~ScopedActivity();
+
+ private:
+ const bool enabled_;
+ const int32 region_id_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ScopedActivity);
+ };
+
+ // Trace collection engine can be registered with this module.
+ // If no engine is registered, ScopedAnnotation and TraceMe are no-ops.
+ class Engine;
+ static void RegisterEngine(Engine*);
+
+ // Forward declaration of the GPU utility classes.
+ class ScopedAnnotation;
+ class TraceMe;
+
+ private:
+ friend class TracingTest;
+
+ static void RegisterEvent(EventCategory id, const char* name);
+ static const char* EventCategoryString(EventCategory category);
+
+ //
+ // Parses event mask expressions in 'value' of the form:
+ // expr ::= <term> (,<term>)*
+ // term ::= <event> | "!" <event>
+ // event ::= "ALL" | <wait_event> | <other_event>
+ // wait_event ::= "ENewSession" | "ECloseSession" | ...
+ // other_event ::= "Send" | "Wait" | ...
+ // ALL denotes all events, <event> turns on tracing for this event, and
+ // !<event> turns off tracing for this event.
+ // If the expression can be parsed correctly it returns true and sets
+ // the event_mask_. Otherwise it returns false and the event_mask_ is left
+ // unchanged.
+ static bool ParseEventMask(const char* flagname, const string& value);
+
+ // Bit mask of enabled trace categories.
+ static uint64 event_mask_;
+
+ // Records the mappings between Threadscape IDs and the "EventCategory" enum.
+ static int32 category_id_[kEventCategoryMax];
+ static std::map<string, int32>* name_map_;
+};
+
+// Trace collection engine that actually implements collection.
+class Tracing::Engine {
+ public:
+ Engine() {}
+ virtual ~Engine();
+
+ // Represents an active annotation.
+ class Annotation {
+ public:
+ Annotation() {}
+ virtual ~Annotation();
+ };
+
+ // Represents an active trace.
+ class Tracer {
+ public:
+ Tracer() {}
+ virtual ~Tracer();
+ };
+
+ private:
+ friend class ScopedAnnotation;
+ friend class TraceMe;
+
+ // Register the specified name as an annotation on the current thread.
+ // Caller should delete the result to remove the annotation.
+ // Annotations from the same thread are destroyed in a LIFO manner.
+ // May return nullptr if annotations are not supported.
+ virtual Annotation* PushAnnotation(StringPiece name) = 0;
+
+ // Start tracing under the specified label. Caller should delete the
+ // result to stop tracing.
+ // May return nullptr if tracing is not supported.
+ virtual Tracer* StartTracing(StringPiece label) = 0;
+};
+
+// This class permits a user to apply annotation on kernels and memcpys
+// when launching them. While an annotation is in scope, all activities
+// within that scope get their names replaced by the annotation. The kernel
+// name replacement is done when constructing the protobuf for sending out to
+// a client (e.g., the stubby requestor) for both API and Activity records.
+//
+// Ownership: The creator of ScopedAnnotation assumes ownership of the object.
+//
+// Usage: {
+// ScopedAnnotation annotation("first set of kernels");
+// Kernel1<<<x,y>>>;
+// LaunchKernel2(); // Which eventually launches a cuda kernel.
+// }
+// In the above scenario, the GPUProf UI would show 2 kernels with the name
+// "first set of kernels" executing -- they will appear as the same kernel.
+class Tracing::ScopedAnnotation {
+ public:
+ explicit ScopedAnnotation(StringPiece name);
+
+ private:
+ std::unique_ptr<Engine::Annotation> annotation_;
+};
+
+// TODO(opensource): clean up the scoped classes for GPU tracing.
+// This class permits user-specified (CPU) tracing activities. A trace
+// activity is started when an object of this class is created and stopped
+// when the object is destroyed.
+class Tracing::TraceMe {
+ public:
+ explicit TraceMe(StringPiece name);
+
+ private:
+ std::unique_ptr<Engine::Tracer> tracer_;
+};
+
+} // namespace port
+} // namespace tensorflow
+
+#if defined(PLATFORM_GOOGLE) && !defined(ANDROID) && !defined(__ANDROID__)
+#include "tensorflow/core/platform/google/tracing_impl.h"
+#else
+#include "tensorflow/core/platform/default/tracing_impl.h"
+#endif
+
+#endif // TENSORFLOW_PLATFORM_TRACING_H_
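To make the Engine hook concrete, here is a hypothetical engine that merely logs; it is not part of this diff, and a real engine would forward to an actual profiler. Registration would typically happen once, early in the program, after which ScopedAnnotation and TraceMe stop being no-ops.

```c++
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"

namespace {

class LoggingEngine : public tensorflow::port::Tracing::Engine {
 public:
  Annotation* PushAnnotation(tensorflow::StringPiece name) override {
    LOG(INFO) << "annotation: " << name;
    return new Annotation;  // Deleted by ScopedAnnotation's destructor.
  }
  Tracer* StartTracing(tensorflow::StringPiece label) override {
    LOG(INFO) << "trace: " << label;
    return new Tracer;  // Deleted by TraceMe's destructor.
  }
};

}  // namespace

void InstallLoggingEngine() {
  tensorflow::port::Tracing::RegisterEngine(new LoggingEngine);
}
```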
diff --git a/tensorflow/core/public/README.md b/tensorflow/core/public/README.md
new file mode 100644
index 0000000000..b1afff87de
--- /dev/null
+++ b/tensorflow/core/public/README.md
@@ -0,0 +1,90 @@
+# TensorFlow
+
+TensorFlow is a computational dataflow graph library.
+
+## Getting started
+
+
+### Python API example
+The following is example Python code that does a simple matrix multiply
+of two constants and fetches the result from a locally running TensorFlow
+process.
+
+First, bring in the following dependency:
+
+//third_party/tensorflow/core/public:tensorflow_py
+
+to get the Python TensorFlow API. If you intend to run TensorFlow within
+the same process, link in the following to the same binary:
+
+//third_party/tensorflow/core/public:tensorflow_std_ops
+
+to get the standard set of op implementations. Then:
+
+```python
+import tensorflow as tf
+
+with tf.Session("local"):
+ input1 = tf.Constant(1.0, shape=[1, 1], name="input1")
+ input2 = tf.Constant(2.0, shape=[1, 1], name="input2")
+ output = tf.MatMul(input1, input2)
+
+ # Run graph and fetch the output
+ result = output.eval()
+ print result
+```
+
+### C++ API Example
+
+If you are running TensorFlow locally, link your binary with
+
+//third_party/tensorflow/core/public:tensorflow_local
+
+and link in the operation implementations you want supported, e.g.,
+
+//third_party/tensorflow/core/public:tensorflow_std_ops
+
+An example program that takes a GraphDef and runs it with TensorFlow
+using the C++ Session API:
+
+```c++
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/tensor.h"
+
+int main(int argc, char** argv) {
+ // Construct your graph.
+ tensorflow::GraphDef graph = ...;
+
+ // Create a Session running TensorFlow locally in process.
+ std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession({}));
+
+ // Initialize the session with the graph.
+ tensorflow::Status s = session->Create(graph);
+ if (!s.ok()) { ... }
+
+ // Specify the 'feeds' of your network if needed.
+ std::vector<std::pair<string, tensorflow::Tensor>> inputs;
+
+ // Run the session, asking for the first output of "my_output".
+ std::vector<tensorflow::Tensor> outputs;
+ s = session->Run(inputs, {"my_output:0"}, {}, &outputs);
+ if (!s.ok()) { ... }
+
+ // Do something with your outputs
+ auto output_vector = outputs[0].vec<float>();
+ if (output_vector(0) > 0.5) { ... }
+
+ // Close the session.
+ session->Close();
+
+ return 0;
+}
+```
+
+For a more fully-featured C++ example, see
+`tensorflow/cc/tutorials/example_trainer.cc`
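The C++ example above elides graph construction. One possible way to obtain a GraphDef, sketched below under the assumption that a graph was serialized to disk beforehand (the path is illustrative), is to read it back with `ReadBinaryProto` from `env.h`:

```c++
#include <memory>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/public/session.h"

int main() {
  // Read a GraphDef that was serialized elsewhere (path is illustrative).
  tensorflow::GraphDef graph;
  tensorflow::Status s = tensorflow::ReadBinaryProto(
      tensorflow::Env::Default(), "/tmp/my_graph.pb", &graph);
  if (!s.ok()) return 1;

  // Hand the graph to a locally running session, as in the example above.
  std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession({}));
  s = session->Create(graph);
  if (!s.ok()) return 1;

  session->Close();
  return 0;
}
```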
diff --git a/tensorflow/core/public/env.h b/tensorflow/core/public/env.h
new file mode 100644
index 0000000000..4024525859
--- /dev/null
+++ b/tensorflow/core/public/env.h
@@ -0,0 +1,273 @@
+#ifndef TENSORFLOW_PUBLIC_ENV_H_
+#define TENSORFLOW_PUBLIC_ENV_H_
+
+#include <string>
+#include <vector>
+#include <stdint.h>
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+class RandomAccessFile;
+class Thread;
+class ThreadOptions;
+class WritableFile;
+
+/// \brief An interface used by the tensorflow implementation to
+/// access operating system functionality like the filesystem etc.
+///
+/// Callers may wish to provide a custom Env object to get fine grain
+/// control.
+///
+/// All Env implementations are safe for concurrent access from
+/// multiple threads without any external synchronization.
+class Env {
+ public:
+ Env() {}
+ virtual ~Env();
+
+ /// \brief Returns a default environment suitable for the current operating
+ /// system.
+ ///
+ /// Sophisticated users may wish to provide their own Env
+ /// implementation instead of relying on this default environment.
+ ///
+ /// The result of Default() belongs to this library and must never be deleted.
+ static Env* Default();
+
+ /// \brief Creates a brand new random access read-only file with the
+ /// specified name.
+ ///
+ /// On success, stores a pointer to the new file in
+ /// *result and returns OK. On failure stores NULL in *result and
+ /// returns non-OK. If the file does not exist, returns a non-OK
+ /// status.
+ ///
+ /// The returned file may be concurrently accessed by multiple threads.
+ virtual Status NewRandomAccessFile(const string& fname,
+ RandomAccessFile** result) = 0;
+
+ /// \brief Creates an object that writes to a new file with the specified
+ /// name.
+ ///
+ /// Deletes any existing file with the same name and creates a
+ /// new file. On success, stores a pointer to the new file in
+ /// *result and returns OK. On failure stores NULL in *result and
+ /// returns non-OK.
+ ///
+ /// The returned file will only be accessed by one thread at a time.
+ virtual Status NewWritableFile(const string& fname,
+ WritableFile** result) = 0;
+
+ /// \brief Creates an object that either appends to an existing file, or
+ /// writes to a new file (if the file does not exist to begin with).
+ ///
+ /// On success, stores a pointer to the new file in *result and
+ /// returns OK. On failure stores NULL in *result and returns
+ /// non-OK.
+ ///
+ /// The returned file will only be accessed by one thread at a time.
+ virtual Status NewAppendableFile(const string& fname,
+ WritableFile** result) = 0;
+
+ /// Returns true iff the named file exists.
+ virtual bool FileExists(const string& fname) = 0;
+
+ /// \brief Stores in *result the names of the children of the specified
+ /// directory. The names are relative to "dir".
+ ///
+ /// Original contents of *result are dropped.
+ virtual Status GetChildren(const string& dir,
+ std::vector<string>* result) = 0;
+
+ /// Deletes the named file.
+ virtual Status DeleteFile(const string& fname) = 0;
+
+ /// Creates the specified directory.
+ virtual Status CreateDir(const string& dirname) = 0;
+
+ /// Deletes the specified directory.
+ virtual Status DeleteDir(const string& dirname) = 0;
+
+ /// Stores the size of fname in *file_size.
+ virtual Status GetFileSize(const string& fname, uint64* file_size) = 0;
+
+ /// \brief Renames file src to target. If target already exists, it will be
+ /// replaced.
+ virtual Status RenameFile(const string& src, const string& target) = 0;
+
+ // TODO(jeff,sanjay): Add back thread/thread-pool support if needed.
+ // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
+ // provide a routine to get the absolute time.
+
+ /// \brief Returns the number of micro-seconds since some fixed point in
+ /// time. Only useful for computing deltas of time.
+ virtual uint64 NowMicros() = 0;
+
+ /// Sleeps/delays the thread for the prescribed number of micro-seconds.
+ virtual void SleepForMicroseconds(int micros) = 0;
+
+ /// \brief Returns a new thread that is running fn() and is identified
+ /// (for debugging/performance-analysis) by "name".
+ ///
+ /// Caller takes ownership of the result and must delete it eventually
+ /// (the deletion will block until fn() stops running).
+ virtual Thread* StartThread(const ThreadOptions& thread_options,
+ const string& name,
+ std::function<void()> fn) TF_MUST_USE_RESULT = 0;
+
+ private:
+ /// No copying allowed
+ Env(const Env&);
+ void operator=(const Env&);
+};
+
+/// A file abstraction for randomly reading the contents of a file.
+class RandomAccessFile {
+ public:
+ RandomAccessFile() {}
+ virtual ~RandomAccessFile();
+
+ /// \brief Reads up to "n" bytes from the file starting at "offset".
+ ///
+ /// "scratch[0..n-1]" may be written by this routine. Sets "*result"
+ /// to the data that was read (including if fewer than "n" bytes were
+ /// successfully read). May set "*result" to point at data in
+ /// "scratch[0..n-1]", so "scratch[0..n-1]" must be live when
+ /// "*result" is used.
+ ///
+ /// On OK returned status: "n" bytes have been stored in "*result".
+ /// On non-OK returned status: [0..n] bytes have been stored in "*result".
+ ///
+ /// Returns OUT_OF_RANGE if fewer than n bytes were stored in "*result"
+ /// because of EOF.
+ ///
+ /// Safe for concurrent use by multiple threads.
+ virtual Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const = 0;
+
+ private:
+ /// No copying allowed
+ RandomAccessFile(const RandomAccessFile&);
+ void operator=(const RandomAccessFile&);
+};
+
+/// \brief A file abstraction for sequential writing.
+///
+/// The implementation must provide buffering since callers may append
+/// small fragments at a time to the file.
+class WritableFile {
+ public:
+ WritableFile() {}
+ virtual ~WritableFile();
+
+ virtual Status Append(const StringPiece& data) = 0;
+ virtual Status Close() = 0;
+ virtual Status Flush() = 0;
+ virtual Status Sync() = 0;
+
+ private:
+ /// No copying allowed
+ WritableFile(const WritableFile&);
+ void operator=(const WritableFile&);
+};
+
+/// \brief An implementation of Env that forwards all calls to another Env.
+///
+/// May be useful to clients who wish to override just part of the
+/// functionality of another Env.
+class EnvWrapper : public Env {
+ public:
+ /// Initializes an EnvWrapper that delegates all calls to *t
+ explicit EnvWrapper(Env* t) : target_(t) {}
+ virtual ~EnvWrapper();
+
+ /// Returns the target to which this Env forwards all calls
+ Env* target() const { return target_; }
+
+ // The following text is boilerplate that forwards all methods to target()
+ Status NewRandomAccessFile(const string& f,
+ RandomAccessFile** r) override {
+ return target_->NewRandomAccessFile(f, r);
+ }
+ Status NewWritableFile(const string& f, WritableFile** r) override {
+ return target_->NewWritableFile(f, r);
+ }
+ Status NewAppendableFile(const string& f, WritableFile** r) override {
+ return target_->NewAppendableFile(f, r);
+ }
+ bool FileExists(const string& f) override { return target_->FileExists(f); }
+ Status GetChildren(const string& dir, std::vector<string>* r) override {
+ return target_->GetChildren(dir, r);
+ }
+ Status DeleteFile(const string& f) override {
+ return target_->DeleteFile(f);
+ }
+ Status CreateDir(const string& d) override {
+ return target_->CreateDir(d);
+ }
+ Status DeleteDir(const string& d) override {
+ return target_->DeleteDir(d);
+ }
+ Status GetFileSize(const string& f, uint64* s) override {
+ return target_->GetFileSize(f, s);
+ }
+ Status RenameFile(const string& s, const string& t) override {
+ return target_->RenameFile(s, t);
+ }
+ uint64 NowMicros() override { return target_->NowMicros(); }
+ void SleepForMicroseconds(int micros) override {
+ target_->SleepForMicroseconds(micros);
+ }
+ Thread* StartThread(const ThreadOptions& thread_options, const string& name,
+ std::function<void()> fn) override {
+ return target_->StartThread(thread_options, name, fn);
+ }
+
+ private:
+ Env* target_;
+};
+
+class Thread {
+ public:
+ Thread() {}
+
+ /// Blocks until the thread of control stops running.
+ virtual ~Thread();
+
+ private:
+ /// No copying allowed
+ Thread(const Thread&);
+ void operator=(const Thread&);
+};
+
+/// \brief Options to configure a Thread.
+///
+/// Note that the options are all hints, and the
+/// underlying implementation may choose to ignore them.
+struct ThreadOptions {
+ /// Thread stack size to use (in bytes).
+ size_t stack_size = 0; // 0: use system default value
+ /// Guard area size to use near thread stacks (in bytes).
+ size_t guard_size = 0; // 0: use system default value
+};
+
+/// A utility routine: reads contents of named file into *data
+Status ReadFileToString(Env* env, const string& fname, string* data);
+
+/// A utility routine: write contents of "data" to file named "fname"
+/// (overwriting existing contents, if any).
+Status WriteStringToFile(Env* env, const string& fname,
+ const StringPiece& data);
+
+/// Reads contents of the named file, parses it as binary-encoded proto data,
+/// and stores it into *proto.
+Status ReadBinaryProto(Env* env, const string& fname,
+ ::tensorflow::protobuf::MessageLite* proto);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_ENV_H_
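Below is a minimal sketch, assuming only the declarations in this header, that wraps the default Env with an EnvWrapper subclass and exercises the file utility routines; `CountingEnv`, the file path, and its contents are hypothetical:

```c++
#include <string>

#include "tensorflow/core/public/env.h"

namespace {

// Hypothetical wrapper: overrides one method, delegates everything else.
class CountingEnv : public tensorflow::EnvWrapper {
 public:
  explicit CountingEnv(tensorflow::Env* base) : tensorflow::EnvWrapper(base) {}

  tensorflow::Status NewRandomAccessFile(
      const std::string& fname,
      tensorflow::RandomAccessFile** result) override {
    ++opens_;  // count how many read-only files were requested
    return target()->NewRandomAccessFile(fname, result);
  }

  int opens() const { return opens_; }

 private:
  int opens_ = 0;
};

}  // namespace

int main() {
  CountingEnv env(tensorflow::Env::Default());

  // Write a small file, then read it back (path and contents illustrative).
  const std::string fname = "/tmp/env_demo.txt";
  TF_CHECK_OK(tensorflow::WriteStringToFile(&env, fname, "hello, env\n"));

  std::string data;
  TF_CHECK_OK(tensorflow::ReadFileToString(&env, fname, &data));

  // Metadata queries go through the same interface.
  tensorflow::uint64 size = 0;
  if (env.FileExists(fname)) {
    TF_CHECK_OK(env.GetFileSize(fname, &size));
  }
  TF_CHECK_OK(env.DeleteFile(fname));
  return 0;
}
```

Subclassing EnvWrapper rather than Env keeps the override surface small: only the methods of interest are replaced, and the rest forward to Env::Default().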
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
new file mode 100644
index 0000000000..a33d5ee6ae
--- /dev/null
+++ b/tensorflow/core/public/session.h
@@ -0,0 +1,125 @@
+#ifndef TENSORFLOW_PUBLIC_SESSION_H_
+#define TENSORFLOW_PUBLIC_SESSION_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+namespace tensorflow {
+
+/// \brief A Session instance lets a caller drive a TensorFlow graph
+/// computation.
+///
+/// When a Session is created with a given target, a new Session object
+/// is bound to the universe of resources specified by that target.
+/// Those resources are available to this session to perform
+/// computation described in the GraphDef. After extending the session
+/// with a graph, the caller uses the Run() API to perform the
+/// computation and potentially fetch outputs as Tensors.
+///
+/// Example:
+///
+/// tensorflow::GraphDef graph;
+/// // ... Create or load graph into 'graph'.
+///
+/// // This example uses the default options which connects
+/// // to a local runtime.
+/// tensorflow::SessionOptions options;
+/// std::unique_ptr<tensorflow::Session>
+/// session(tensorflow::NewSession(options));
+///
+/// // Create the session with this graph.
+/// tensorflow::Status s = session->Create(graph);
+/// if (!s.ok()) { ... }
+///
+/// // Run the graph and fetch the first output of the "output"
+/// // operation, and also run to but do not return anything
+/// // for the "update_state" operation.
+/// std::vector<tensorflow::Tensor> outputs;
+/// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs);
+/// if (!s.ok()) { ... }
+///
+/// // Map the output as a flattened float tensor, and do something
+/// // with it.
+/// auto output_tensor = outputs[0].flat<float>();
+/// if (output_tensor(0) > 0.5) { ... }
+///
+/// // Close the session to release the resources associated with
+/// // this session.
+/// session->Close();
+///
+/// A Session allows concurrent calls to Run(), though a Session must
+/// be created / extended by a single thread.
+///
+/// Only one thread must call Close(), and Close() must only be called
+/// after all other calls to Run() have returned.
+class Session {
+ public:
+ /// \brief Create the graph to be used for the session.
+ ///
+ /// Returns an error if this session has already been created with a
+ /// graph. To re-use the session with a different graph, the caller
+ /// must Close() the session first.
+ virtual Status Create(const GraphDef& graph) = 0;
+
+ /// \brief Adds operations to the graph that is already registered with the
+ /// Session.
+ ///
+ /// The names of new operations in "graph" must not exist in the
+ /// graph that is already registered.
+ virtual Status Extend(const GraphDef& graph) = 0;
+
+ /// \brief Runs the graph with the provided input tensors and fills
+ /// 'outputs' for the endpoints specified in 'output_tensor_names'.
+ /// Runs to but does not return Tensors for the nodes in
+ /// 'target_node_names'.
+ ///
+ /// The order of tensors in 'outputs' will match the order provided
+ /// by 'output_tensor_names'.
+ ///
+ /// If Run returns OK(), then outputs->size() will be equal to
+ /// output_tensor_names.size(). If Run does not return OK(), the
+ /// state of outputs is undefined.
+ ///
+ /// REQUIRES: The name of each Tensor of the input or output must
+ /// match a "Tensor endpoint" in the GraphDef passed to Create().
+ ///
+ /// REQUIRES: outputs is not nullptr if output_tensor_names is non-empty.
+ virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) = 0;
+
+ /// \brief Closes this session.
+ ///
+ /// Closing a session releases the resources used by this session
+ /// on the TensorFlow runtime (specified during session creation by
+ /// the 'SessionOptions::target' field).
+ virtual Status Close() = 0;
+
+ virtual ~Session() {}
+};
+
+/// \brief Create a new session with the given options.
+///
+/// If a new session object could not be created, this function will
+/// return nullptr.
+Session* NewSession(const SessionOptions& options);
+
+/// \brief Create a new session with the given options.
+///
+/// If session creation succeeds, the new Session will be stored in
+/// *out_session, the caller will take ownership of the returned
+/// *out_session, and this function will return OK(). Otherwise, this
+/// function will return an error status.
+Status NewSession(const SessionOptions& options, Session** out_session);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_SESSION_H_
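The class comment above shows fetches and targets but no feeds. Below is a small sketch of a Run() call that supplies one, assuming a graph with endpoints named "input:0" and "output:0" and a node "update_state" (all hypothetical):

```c++
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"

// Sketch: assumes 'session' was created from a graph that defines an
// "input:0" endpoint, an "output:0" endpoint, and an "update_state" node.
tensorflow::Status RunOneStep(tensorflow::Session* session,
                              const tensorflow::Tensor& input) {
  std::vector<std::pair<std::string, tensorflow::Tensor>> feeds = {
      {"input:0", input}};
  std::vector<tensorflow::Tensor> outputs;

  // Fetch "output:0"; also run "update_state" but do not return its value.
  tensorflow::Status s =
      session->Run(feeds, {"output:0"}, {"update_state"}, &outputs);
  if (s.ok()) {
    // outputs.size() == 1, matching the single fetch requested above.
    LOG(INFO) << "fetched " << outputs[0].DebugString();
  }
  return s;
}
```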
diff --git a/tensorflow/core/public/session_options.h b/tensorflow/core/public/session_options.h
new file mode 100644
index 0000000000..11d52426ac
--- /dev/null
+++ b/tensorflow/core/public/session_options.h
@@ -0,0 +1,50 @@
+#ifndef TENSORFLOW_PUBLIC_SESSION_OPTIONS_H_
+#define TENSORFLOW_PUBLIC_SESSION_OPTIONS_H_
+
+#include <string>
+#include "tensorflow/core/framework/config.pb.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+class Env;
+
+/// Configuration information for a Session.
+struct SessionOptions {
+ /// The environment to use.
+ Env* env;
+
+ /// \brief The TensorFlow runtime to connect to.
+ ///
+ /// If 'target' is empty or unspecified, the local TensorFlow runtime
+ /// implementation will be used. Otherwise, the TensorFlow engine
+ /// defined by 'target' will be used to perform all computations.
+ ///
+ /// "target" can be either a single entry or a comma separated list
+ /// of entries. Each entry is a resolvable address of the
+ /// following format:
+ /// local
+ /// ip:port
+ /// host:port
+ /// ... other system-specific formats to identify tasks and jobs ...
+ ///
+ /// NOTE: at the moment 'local' maps to an in-process service-based
+ /// runtime.
+ ///
+ /// Upon creation, a single session affines itself to one of the
+ /// remote processes, with possible load balancing choices when the
+ /// "target" resolves to a list of possible processes.
+ ///
+ /// If the session disconnects from the remote process during its
+ /// lifetime, session calls may fail immediately.
+ string target;
+
+ /// Configuration options.
+ ConfigProto config;
+
+ SessionOptions();
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_SESSION_OPTIONS_H_
diff --git a/tensorflow/core/public/status.h b/tensorflow/core/public/status.h
new file mode 100644
index 0000000000..d0405b8876
--- /dev/null
+++ b/tensorflow/core/public/status.h
@@ -0,0 +1,96 @@
+#ifndef TENSORFLOW_PUBLIC_STATUS_H_
+#define TENSORFLOW_PUBLIC_STATUS_H_
+
+#include <iosfwd>
+#include <string>
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class Status {
+ public:
+ /// Create a success status.
+ Status() : state_(NULL) {}
+ ~Status();
+
+ /// \brief Create a status with the specified error code and msg as a
+ /// human-readable string containing more detailed information.
+ Status(tensorflow::error::Code code, tensorflow::StringPiece msg);
+
+ /// Copy the specified status.
+ Status(const Status& s);
+ void operator=(const Status& s);
+
+ static Status OK() { return Status(); }
+
+ /// Returns true iff the status indicates success.
+ bool ok() const { return (state_ == NULL); }
+
+ tensorflow::error::Code code() const {
+ return ok() ? tensorflow::error::OK : state_->code;
+ }
+
+ const string& error_message() const {
+ return ok() ? empty_string() : state_->msg;
+ }
+
+ bool operator==(const Status& x) const;
+ bool operator!=(const Status& x) const;
+
+ /// \brief If "ok()", stores "new_status" into *this. If "!ok()", preserves
+ /// the current status, but may augment with additional information
+ /// about "new_status".
+ ///
+ /// Convenient way of keeping track of the first error encountered.
+ /// Instead of:
+ /// if (overall_status.ok()) overall_status = new_status
+ /// Use:
+ /// overall_status.Update(new_status);
+ void Update(const Status& new_status);
+
+ /// \brief Return a string representation of this status suitable for
+ /// printing. Returns the string "OK" for success.
+ string ToString() const;
+
+ private:
+ static const string& empty_string();
+ struct State {
+ tensorflow::error::Code code;
+ string msg;
+ };
+ /// OK status has a NULL state_. Otherwise, state_ points to
+ /// a State structure containing the error code and message(s).
+ State* state_;
+
+ void SlowCopyFrom(const State* src);
+};
+
+inline Status::Status(const Status& s)
+ : state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {}
+
+inline void Status::operator=(const Status& s) {
+ /// The following condition catches both aliasing (when this == &s),
+ /// and the common case where both s and *this are ok.
+ if (state_ != s.state_) {
+ SlowCopyFrom(s.state_);
+ }
+}
+
+inline bool Status::operator==(const Status& x) const {
+ return (this->state_ == x.state_) || (ToString() == x.ToString());
+}
+
+inline bool Status::operator!=(const Status& x) const { return !(*this == x); }
+
+std::ostream& operator<<(std::ostream& os, const Status& x);
+
+typedef std::function<void(const Status&)> StatusCallback;
+
+#define TF_CHECK_OK(val) CHECK_EQ(::tensorflow::Status::OK(), (val))
+#define TF_QCHECK_OK(val) QCHECK_EQ(::tensorflow::Status::OK(), (val))
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_STATUS_H_
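A short sketch of how Status values are typically created, aggregated with Update(), and checked; the helper function and error text are illustrative:

```c++
#include "tensorflow/core/public/status.h"

namespace {

// Hypothetical validation helper: returns OK or an INVALID_ARGUMENT error.
tensorflow::Status CheckPositive(int v) {
  if (v > 0) return tensorflow::Status::OK();
  return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
                            "value must be positive");
}

tensorflow::Status CheckAll() {
  tensorflow::Status overall;
  // Update() keeps the first error encountered; later failures do not
  // replace it.
  overall.Update(CheckPositive(3));
  overall.Update(CheckPositive(-1));  // first failure is retained
  overall.Update(CheckPositive(-2));  // 'overall' is already !ok()
  return overall;
}

}  // namespace

int main() {
  tensorflow::Status s = CheckAll();
  LOG(INFO) << s.ToString();       // prints the INVALID_ARGUMENT message
  TF_CHECK_OK(CheckPositive(1));   // CHECK-fails if the status is not OK
  return s.ok() ? 0 : 1;
}
```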
diff --git a/tensorflow/core/public/tensor.h b/tensorflow/core/public/tensor.h
new file mode 100644
index 0000000000..6c6ff0f58a
--- /dev/null
+++ b/tensorflow/core/public/tensor.h
@@ -0,0 +1,472 @@
+#ifndef TENSORFLOW_PUBLIC_TENSOR_H_
+#define TENSORFLOW_PUBLIC_TENSOR_H_
+
+#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_description.pb.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class TensorBuffer; // Forward declaration.
+class TensorCApi;
+
+/// Represents an n-dimensional array of values.
+class Tensor {
+ public:
+ /// Default Tensor constructor. Creates a 1-dimensional, 0-element float tensor.
+ Tensor();
+
+ /// \brief Creates a Tensor of the given datatype and shape.
+ ///
+ /// The underlying buffer is allocated using a CPUAllocator.
+ Tensor(DataType type, const TensorShape& shape);
+
+ /// \brief Creates a tensor with the input datatype and shape, using the
+ /// allocator 'a' to allocate the underlying buffer.
+ ///
+ /// 'a' must outlive the lifetime of this Tensor.
+ Tensor(Allocator* a, DataType type, const TensorShape& shape);
+
+ /// Creates an uninitialized Tensor of the given data type.
+ explicit Tensor(DataType type);
+
+ Tensor(const Tensor& other); /// Copy constructor.
+
+ ~Tensor();
+
+ /// Returns the data type.
+ DataType dtype() const { return type_; }
+
+ /// Returns the shape of the tensor.
+ const TensorShape& shape() const { return shape_; }
+
+ /// \brief Convenience accessor for the tensor shape.
+ ///
+ /// For all shape accessors, see comments for relevant methods of
+ /// TensorShape in tensor_shape.h.
+ int dims() const { return shape().dims(); }
+
+ /// Convenience accessor for the tensor shape.
+ int64 dim_size(int d) const { return shape().dim_size(d); }
+
+ /// Convenience accessor for the tensor shape.
+ int64 NumElements() const { return shape().num_elements(); }
+
+ bool IsSameSize(const Tensor& b) const {
+ return shape().IsSameSize(b.shape());
+ }
+
+ /// Has this Tensor been initialized?
+ bool IsInitialized() const;
+
+ /// Returns the estimated memory usage of this tensor.
+ size_t TotalBytes() const;
+
+ /// Assign operator. This tensor shares other's underlying storage.
+ Tensor& operator=(const Tensor& other) {
+ CopyFromInternal(other, other.shape());
+ return *this;
+ }
+
+ /// \brief Copy the other tensor into this tensor and reshape it.
+ ///
+ /// This tensor shares other's underlying storage. Returns
+ /// true iff other.shape() has the same number of elements as the
+ /// given "shape".
+ bool CopyFrom(const Tensor& other,
+ const TensorShape& shape) TF_MUST_USE_RESULT {
+ if (other.NumElements() != shape.num_elements()) return false;
+ CopyFromInternal(other, shape);
+ return true;
+ }
+
+ /// \brief Slice this tensor along the 1st dimension.
+ ///
+ /// I.e., the returned tensor satisfies
+ /// returned[i, ...] == this[dim0_start + i, ...].
+ /// The returned tensor shares the underlying tensor buffer with this
+ /// tensor.
+ ///
+ /// NOTE: The returned tensor may not satisfy the same alignment
+ /// requirement as this tensor depending on the shape. The caller
+ /// must check the returned tensor's alignment before calling certain
+ /// methods that have an alignment requirement (e.g., flat(), tensor()).
+ ///
+ /// REQUIRES: dims() >= 1
+ /// REQUIRES: 0 <= dim0_start <= dim0_limit <= dim_size(0)
+ Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
+
+ /// \brief Parse "other' and construct the tensor.
+
+ /// Returns true iff the
+ /// parsing succeeds. If the parsing fails, the state of "*this" is
+ /// unchanged.
+ bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
+ bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;
+
+ /// \brief Fills in "proto" with "*this" tensor's content.
+ ///
+ /// AsProtoField() fills in the repeated field for proto.dtype(), while
+ /// AsProtoTensorContent() encodes the content in proto.tensor_content() in a
+ /// compact form.
+ void AsProtoField(TensorProto* proto) const;
+ void AsProtoTensorContent(TensorProto* proto) const;
+
+ /// \brief Return the Tensor data as an Eigen::Tensor with the type and
+ /// sizes of this Tensor.
+ ///
+ /// Use these methods when you know the data type and the number of
+ /// dimensions of the Tensor and you want an Eigen::Tensor
+ /// automatically sized to the Tensor sizes. The implementation
+ /// CHECK-fails if either the type or the sizes mismatch.
+ ///
+ /// Example:
+ /// typedef float T;
+ /// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
+ /// auto mat = my_mat.matrix<T>(); // 2D Eigen::Tensor, 3 x 5.
+ /// auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
+ /// auto vec = my_mat.vec<T>(); // CHECK fails as my_mat is 2D.
+ /// auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
+ /// auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
+ template <typename T>
+ typename TTypes<T>::Vec vec() {
+ return tensor<T, 1>();
+ }
+
+ template <typename T>
+ typename TTypes<T>::Matrix matrix() {
+ return tensor<T, 2>();
+ }
+
+ template <typename T, size_t NDIMS>
+ typename TTypes<T, NDIMS>::Tensor tensor();
+
+ /// \brief Return the Tensor data as an Eigen::Tensor of the data type and a
+ /// specified shape.
+ ///
+ /// These methods allow you to access the data with the dimensions
+ /// and sizes of your choice. You do not need to know the number of
+ /// dimensions of the Tensor to call them. However, they CHECK that
+ /// the type matches and that the requested dimensions create an
+ /// Eigen::Tensor with the same number of elements as the Tensor.
+ ///
+ /// Example:
+ /// typedef float T;
+ /// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
+ /// // 1D Eigen::Tensor, size 60:
+ /// auto flat = my_ten.flat<T>();
+ /// // 2D Eigen::Tensor 12 x 5:
+ /// auto inner = my_ten.flat_inner_dims<T>();
+ /// // 2D Eigen::Tensor 4 x 15:
+ /// auto outer = my_ten.shaped<T, 2>({4, 15});
+ /// // CHECK fails, bad num elements:
+ /// auto outer = my_ten.shaped<T, 2>({4, 8});
+ /// // 3D Eigen::Tensor 6 x 5 x 2:
+ /// auto weird = my_ten.shaped<T, 3>({6, 5, 2});
+ /// // CHECK fails, type mismatch:
+ /// auto bad = my_ten.flat<int32>();
+ template <typename T>
+ typename TTypes<T>::Flat flat() {
+ return shaped<T, 1>({NumElements()});
+ }
+
+ template <typename T>
+ typename TTypes<T>::UnalignedFlat unaligned_flat() {
+ return unaligned_shaped<T, 1>({NumElements()});
+ }
+
+ /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
+ /// Tensor dimensions but the last one into the first dimension of the result.
+ template <typename T>
+ typename TTypes<T>::Matrix flat_inner_dims() {
+ int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
+ if (last_size == 0) {
+ DCHECK_EQ(NumElements(), 0);
+ // Return something empty, avoiding divide by 0
+ return shaped<T, 2>({0, 0});
+ } else {
+ return shaped<T, 2>({NumElements() / last_size, last_size});
+ }
+ }
+
+ /// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
+ /// Tensor dimensions but the first one into the last dimension of the result.
+ template <typename T>
+ typename TTypes<T>::Matrix flat_outer_dims() {
+ int64 first_size = dims() > 0 ? dim_size(0) : 1;
+ if (first_size == 0) {
+ DCHECK_EQ(NumElements(), 0);
+ // Return something empty, avoiding divide by 0
+ return shaped<T, 2>({0, 0});
+ } else {
+ return shaped<T, 2>({first_size, NumElements() / first_size});
+ }
+ }
+
+ template <typename T, size_t NDIMS>
+ typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
+
+ template <typename T, size_t NDIMS>
+ typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
+ gtl::ArraySlice<int64> new_sizes);
+
+ /// \brief Return the Tensor data as a TensorMap of fixed size 1:
+ /// TensorMap<TensorFixedSize<T, 1>>.
+ ///
+ /// Using scalar() allows the compiler to perform optimizations as the
+ /// size of the tensor is known at compile time.
+ template <typename T>
+ typename TTypes<T>::Scalar scalar();
+
+ /// Const versions of all the methods above.
+ template <typename T>
+ typename TTypes<T>::ConstVec vec() const {
+ return tensor<T, 1>();
+ }
+
+ template <typename T>
+ typename TTypes<T>::ConstMatrix matrix() const {
+ return tensor<T, 2>();
+ }
+
+ template <typename T, size_t NDIMS>
+ typename TTypes<T, NDIMS>::ConstTensor tensor() const;
+
+ template <typename T>
+ typename TTypes<T>::ConstFlat flat() const {
+ return shaped<T, 1>({NumElements()});
+ }
+
+ template <typename T>
+ typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
+ return unaligned_shaped<T, 1>({NumElements()});
+ }
+
+ template <typename T>
+ typename TTypes<T>::ConstMatrix flat_inner_dims() const {
+ int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
+ if (last_size == 0) {
+ DCHECK_EQ(NumElements(), 0);
+ // Return something empty, avoiding divide by 0
+ return shaped<T, 2>({0, 0});
+ } else {
+ return shaped<T, 2>({NumElements() / last_size, last_size});
+ }
+ }
+
+ template <typename T>
+ typename TTypes<T>::ConstMatrix flat_outer_dims() const {
+ int64 first_size = dims() > 0 ? dim_size(0) : 1;
+ if (first_size == 0) {
+ DCHECK_EQ(NumElements(), 0);
+ // Return something empty, avoiding divide by 0
+ return shaped<T, 2>({0, 0});
+ } else {
+ return shaped<T, 2>({first_size, NumElements() / first_size});
+ }
+ }
+
+ template <typename T, size_t NDIMS>
+ typename TTypes<T, NDIMS>::ConstTensor shaped(
+ gtl::ArraySlice<int64> new_sizes) const;
+ template <typename T, size_t NDIMS>
+ typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
+ gtl::ArraySlice<int64> new_sizes) const;
+
+ template <typename T>
+ typename TTypes<T>::ConstScalar scalar() const;
+
+ /// Render the first max_entries values in *this into a string.
+ string SummarizeValue(int64 max_entries) const;
+
+ /// A human-readable summary of the Tensor suitable for debugging.
+ string DebugString() const;
+
+ /// Fill in the TensorDescription proto with metadata about the
+ /// Tensor that is useful for monitoring and debugging.
+ void FillDescription(TensorDescription* description) const;
+
+ /// \brief Returns a StringPiece mapping the current tensor's buffer.
+ ///
+ /// The returned StringPiece may point to a memory location on devices
+ /// that the CPU cannot address directly.
+ ///
+ /// NOTE: The underlying Tensor buffer is refcounted, so the lifetime
+ /// of the contents mapped by the StringPiece matches the lifetime of
+ /// the buffer; callers should arrange to make sure the buffer does
+ /// not get destroyed while the StringPiece is still used.
+ ///
+ /// REQUIRES: DataTypeCanUseMemcpy(dtype()).
+ StringPiece tensor_data() const;
+
+ private:
+ DataType type_;
+ TensorShape shape_;
+ TensorBuffer* buf_;
+
+ friend class DMAHelper;
+ friend class TensorCApi;
+ friend class VariableOp; // For access to set_shape
+ friend class AutoReloadVariableOp; // For access to set_shape
+
+ // Creates a tensor with the input datatype, shape and buf.
+ //
+ // Acquires a ref on buf that belongs to this Tensor.
+ Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
+
+ bool CanUseDMA() const;
+
+ // Only needed by variable op to set the shape of an uninitialized
+ // Tensor.
+ // TODO: Remove this when we have a better story for detecting
+ // uninitialized tensors.
+ void set_shape(const TensorShape& shape) { shape_ = shape; }
+
+ void CopyFromInternal(const Tensor& other, const TensorShape& shape);
+
+ template <typename T>
+ T* base() const;
+};
+
+// Implementation details
+
+// Interface to access the raw ref-counted data buffer.
+class TensorBuffer : public core::RefCounted {
+ public:
+ ~TensorBuffer() override {}
+
+ // data() points to a memory region of size() bytes.
+ virtual void* data() const = 0;
+ virtual size_t size() const = 0;
+
+ // If this TensorBuffer is a sub-buffer of another TensorBuffer,
+ // returns that TensorBuffer. Otherwise, returns this.
+ virtual TensorBuffer* root_buffer() = 0;
+
+ // Fill metadata about the allocation into the proto.
+ virtual void FillAllocationDescription(
+ AllocationDescription* proto) const = 0;
+
+ template <typename T>
+ T* base() const {
+ return reinterpret_cast<T*>(data());
+ }
+};
+
+inline void CheckEigenAlignment(const void* ptr) {
+#if EIGEN_ALIGN == 1
+ CHECK_EQ(reinterpret_cast<intptr_t>(ptr) % EIGEN_ALIGN_BYTES, 0);
+#endif
+}
+
+template <typename T>
+T* Tensor::base() const {
+ return buf_ == nullptr ? nullptr : buf_->base<T>();
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckEigenAlignment(base<T>());
+ return typename TTypes<T, NDIMS>::Tensor(base<T>(),
+ shape().AsEigenDSizes<NDIMS>());
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
+ CheckEigenAlignment(base<T>());
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
+ shape().AsEigenDSizes<NDIMS>());
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
+ gtl::ArraySlice<int64> new_sizes) {
+ CheckEigenAlignment(base<T>());
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CHECK_EQ(NDIMS, new_sizes.size());
+ int64 new_num_elements = 1;
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ for (int d = 0; d < NDIMS; d++) {
+ new_num_elements *= new_sizes[d];
+ dims[d] = new_sizes[d];
+ }
+ CHECK_EQ(new_num_elements, NumElements());
+ return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
+ gtl::ArraySlice<int64> new_sizes) {
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CHECK_EQ(NDIMS, new_sizes.size());
+ int64 new_num_elements = 1;
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ for (int d = 0; d < NDIMS; d++) {
+ new_num_elements *= new_sizes[d];
+ dims[d] = new_sizes[d];
+ }
+ CHECK_EQ(new_num_elements, NumElements());
+ return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
+ gtl::ArraySlice<int64> new_sizes) const {
+ CheckEigenAlignment(base<T>());
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CHECK_EQ(NDIMS, new_sizes.size());
+ int64 new_num_elements = 1;
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ for (int d = 0; d < NDIMS; d++) {
+ new_num_elements *= new_sizes[d];
+ dims[d] = new_sizes[d];
+ }
+ CHECK_EQ(new_num_elements, NumElements());
+ return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
+ gtl::ArraySlice<int64> new_sizes) const {
+ CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CHECK_EQ(NDIMS, new_sizes.size());
+ int64 new_num_elements = 1;
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ for (int d = 0; d < NDIMS; d++) {
+ new_num_elements *= new_sizes[d];
+ dims[d] = new_sizes[d];
+ }
+ CHECK_EQ(new_num_elements, NumElements());
+ return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
+}
+
+template <typename T>
+typename TTypes<T>::Scalar Tensor::scalar() {
+ CheckEigenAlignment(base<T>());
+ CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+ return typename TTypes<T>::Scalar(base<T>());
+}
+
+template <typename T>
+typename TTypes<T>::ConstScalar Tensor::scalar() const {
+ CheckEigenAlignment(base<T>());
+ CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+ return typename TTypes<T>::ConstScalar(base<T>());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_TENSOR_H_
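A compact sketch exercising the accessors documented in the Tensor class above; the shape and values are arbitrary:

```c++
#include "tensorflow/core/public/tensor.h"

int main() {
  // A 3 x 5 float tensor backed by the default CPU allocator.
  tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({3, 5}));

  // 2D Eigen::Tensor view; the type and rank must match or the CHECK fails.
  auto mat = t.matrix<float>();
  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 5; ++c) {
      mat(r, c) = r * 5 + c;
    }
  }

  // 1D view over the same 15 elements.
  auto flat = t.flat<float>();
  CHECK_EQ(flat.size(), 15);

  // Rows [1, 3) of the first dimension; shares the underlying buffer.
  tensorflow::Tensor rows = t.Slice(1, 3);
  CHECK_EQ(rows.dim_size(0), 2);

  LOG(INFO) << t.DebugString();
  return 0;
}
```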
diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h
new file mode 100644
index 0000000000..fe1846319e
--- /dev/null
+++ b/tensorflow/core/public/tensor_c_api.h
@@ -0,0 +1,243 @@
+// TODO(jeff,sanjay): Rename to tensorflow/public/c_api.h
+#ifndef TENSORFLOW_PUBLIC_TENSOR_C_API_H_
+#define TENSORFLOW_PUBLIC_TENSOR_C_API_H_
+
+#include <stddef.h>
+
+// --------------------------------------------------------------------------
+// C API for TensorFlow.
+//
+// The API leans towards simplicity and uniformity instead of convenience
+// since most usage will be by language-specific wrappers.
+//
+// Conventions:
+// * We use the prefix TF_ for everything in the API.
+// * Objects are always passed around as pointers to opaque structs
+// and these structs are allocated/deallocated via the API.
+// * TF_Status holds error information. It is an object type
+// and threfore is passed around as a pointer to an opaque
+// struct as mentioned above.
+// * Every call that has a TF_Status* argument clears it on success
+// and fills it with error info on failure.
+//
+// Questions left to address:
+// * Might need to add stride info to TF_Tensor?
+// * Might at some point need a way for callers to provide their own Env.
+// * Should we remove the TF_Status arg from TF_AddProto calls and only
+// report errors later (e.g., on the Run call)?
+// * Should dimensions be unsigned instead of signed?
+// * Maybe add TF_TensorShape that encapsulates dimension info.
+//
+// Design decisions made:
+// * Backing store for tensor memory has an associated deallocation
+// function. This deallocation function will point to client code
+// for tensors populated by the client. So the client can do things
+// like shadowing a numpy array.
+// * We do not provide TF_OK since it is not strictly necessary and we
+// are not optimizing for convenience.
+// * We make the assumption that one session has one graph. This should be
+// fine since we have the ability to run sub-graphs.
+// * We are not providing TF_AddNode/TF_AddNodes to better support
+// languages/platforms where proto is not available. This is because
+// we can just point authors of bindings at the .proto file and the
+// proto serialization spec and they can do the right thing for
+// their language.
+// * We could allow NULL for some arguments (e.g., NULL options arg).
+// However since convenience is not a primary goal, we don't do this.
+// * Devices are not in this API. Instead, they are created/used internally
+// and the API just provides high level controls over the number of
+// devices of each type.
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// --------------------------------------------------------------------------
+// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
+// The enum values here are identical to corresponding values in types.proto.
+typedef enum {
+ TF_FLOAT = 1,
+ TF_DOUBLE = 2,
+ TF_INT32 = 3, // Int32 tensors are always in 'host' memory.
+ TF_UINT8 = 4,
+ TF_INT16 = 5,
+ TF_INT8 = 6,
+ TF_STRING = 7,
+ TF_COMPLEX = 8, // Single-precision complex
+ TF_INT64 = 9,
+ TF_BOOL = 10,
+ TF_QINT8 = 11, // Quantized int8
+ TF_QUINT8 = 12, // Quantized uint8
+ TF_QINT32 = 13, // Quantized int32
+ TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops.
+} TF_DataType;
+
+// --------------------------------------------------------------------------
+// TF_Code holds an error code. The enum values here are identical to
+// corresponding values in error_codes.proto.
+typedef enum {
+ TF_OK = 0,
+ TF_CANCELLED = 1,
+ TF_UNKNOWN = 2,
+ TF_INVALID_ARGUMENT = 3,
+ TF_DEADLINE_EXCEEDED = 4,
+ TF_NOT_FOUND = 5,
+ TF_ALREADY_EXISTS = 6,
+ TF_PERMISSION_DENIED = 7,
+ TF_UNAUTHENTICATED = 16,
+ TF_RESOURCE_EXHAUSTED = 8,
+ TF_FAILED_PRECONDITION = 9,
+ TF_ABORTED = 10,
+ TF_OUT_OF_RANGE = 11,
+ TF_UNIMPLEMENTED = 12,
+ TF_INTERNAL = 13,
+ TF_UNAVAILABLE = 14,
+ TF_DATA_LOSS = 15,
+} TF_Code;
+
+// --------------------------------------------------------------------------
+// TF_Status holds error information. It either has an OK code, or
+// else an error code with an associated error message.
+typedef struct TF_Status TF_Status;
+
+// Return a new status object.
+extern TF_Status* TF_NewStatus();
+
+// Delete a previously created status object.
+extern void TF_DeleteStatus(TF_Status*);
+
+// Record <code, msg> in *s. Any previous information is lost.
+// A common use is to clear a status: TF_SetStatus(s, TF_OK, "");
+extern void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg);
+
+// Return the code recorded in *s.
+extern TF_Code TF_GetCode(const TF_Status* s);
+
+// Return a pointer to the error message in *s. The return value
+// points to memory that is only usable until the next mutation to *s.
+// Always returns an empty string if TF_GetCode(s) is TF_OK.
+extern const char* TF_Message(const TF_Status* s);
+
+// --------------------------------------------------------------------------
+// TF_Tensor holds a multi-dimensional array of elements of a single data type.
+// For all types other than TF_STRING, the data buffer stores elements
+// in row major order. E.g. if data is treated as a vector of TF_DataType:
+//
+// element 0: index (0, ..., 0)
+// element 1: index (0, ..., 1)
+// ...
+//
+// TODO(jeff,sanjay): Define format for TF_STRING tensors. Perhaps:
+// start_offset: array[uint64]
+// data: byte[...]
+//
+// String length is encoded (varint?) starting at data[start_offset[i]]
+// String contents follow immediately after string length.
+
+typedef struct TF_Tensor TF_Tensor;
+
+// Return a new tensor that holds the bytes data[0,len-1].
+//
+// The data will be deallocated by a subsequent call to TF_DeleteTensor via:
+// (*deallocator_fn)(data, len, deallocator_arg)
+// Clients can provide a custom deallocator function so they can pass in
+// memory managed by something like numpy.
+extern TF_Tensor* TF_NewTensor(TF_DataType, long long* dims, int num_dims,
+ void* data, size_t len,
+ void (*deallocator)(void* data, size_t len,
+ void* arg),
+ void* deallocator_arg);
+
+// Destroy a tensor.
+extern void TF_DeleteTensor(TF_Tensor*);
+
+// Return the type of a tensor element.
+extern TF_DataType TF_TensorType(const TF_Tensor*);
+
+// Return the number of dimensions that the tensor has.
+extern int TF_NumDims(const TF_Tensor*);
+
+// Return the length of the tensor in the "dim_index" dimension.
+// REQUIRES: 0 <= dim_index < TF_NumDims(tensor)
+extern long long TF_Dim(const TF_Tensor* tensor, int dim_index);
+
+// Return the size of the underlying data in bytes.
+extern size_t TF_TensorByteSize(const TF_Tensor*);
+
+// Return a pointer to the underlying data buffer.
+extern void* TF_TensorData(const TF_Tensor*);
+
+// --------------------------------------------------------------------------
+// TF_SessionOptions holds options that can be passed during session creation.
+typedef struct TF_SessionOptions TF_SessionOptions;
+
+// Return a new options object.
+extern TF_SessionOptions* TF_NewSessionOptions();
+
+// Set the target in TF_SessionOptions.options.
+// target can be empty, a single entry, or a comma separated list of entries.
+// Each entry is in one of the following formats:
+// "local"
+// ip:port
+// host:port
+extern void TF_SetTarget(TF_SessionOptions* options, const char* target);
+
+// Set the config in TF_SessionOptions.options.
+// config should be a serialized brain.ConfigProto proto.
+// If config was not parsed successfully as a ConfigProto, record the
+// error information in *status.
+extern void TF_SetConfig(TF_SessionOptions* options, const char* config,
+ size_t config_len, TF_Status* status);
+
+// Destroy an options object.
+extern void TF_DeleteSessionOptions(TF_SessionOptions*);
+
+// TODO(jeff,sanjay):
+// - export functions to set Config fields
+
+// --------------------------------------------------------------------------
+// TF_Session manages a single graph and execution.
+typedef struct TF_Session TF_Session;
+
+// Return a new execution session, or NULL on error.
+extern TF_Session* TF_NewSession(const TF_SessionOptions*, TF_Status* status);
+
+// Close a session.
+extern void TF_CloseSession(TF_Session*, TF_Status* status);
+
+// Destroy a session. Even if error information is recorded in *status,
+// this call discards all resources associated with the session.
+extern void TF_DeleteSession(TF_Session*, TF_Status* status);
+
+// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and
+// add the nodes in that GraphDef to the graph for the session.
+extern void TF_ExtendGraph(TF_Session*, const void* proto, size_t proto_len,
+ TF_Status*);
+
+// Run the graph associated with the session starting with the
+// supplied inputs (inputs[0,ninputs-1]). Regardless of success or
+// failure, inputs[] become the property of the implementation (the
+// implementation will eventually call TF_DeleteTensor on each input).
+//
+// On success, the tensors corresponding to output_names[0,noutputs-1]
+// are placed in outputs[], and these outputs[] become the property
+// of the caller (the caller must eventually call TF_DeleteTensor on
+// them).
+//
+// On failure, outputs[] contains nulls.
+extern void TF_Run(TF_Session*,
+ // Input tensors
+ const char** input_names, TF_Tensor** inputs, int ninputs,
+ // Output tensors
+ const char** output_tensor_names, TF_Tensor** outputs,
+ int noutputs,
+ // Target nodes
+ const char** target_node_names, int ntargets,
+ // Output status
+ TF_Status*);
+
+#ifdef __cplusplus
+} /* end extern "C" */
+#endif
+
+#endif // TENSORFLOW_PUBLIC_TENSOR_C_API_H_
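A sketch of the call sequence implied by the comments above: create a status and options, open a session, and build a tensor whose backing store is released through the caller-supplied deallocator. The target and shape are illustrative, and the TF_ExtendGraph/TF_Run steps are elided:

```c++
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "tensorflow/core/public/tensor_c_api.h"

// Deallocator handed to TF_NewTensor: invoked when the tensor is deleted.
static void FreeBuffer(void* data, size_t len, void* arg) { free(data); }

int main() {
  TF_Status* status = TF_NewStatus();

  TF_SessionOptions* opts = TF_NewSessionOptions();
  TF_SetTarget(opts, "local");
  TF_Session* session = TF_NewSession(opts, status);
  TF_DeleteSessionOptions(opts);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "NewSession failed: %s\n", TF_Message(status));
    TF_DeleteStatus(status);
    return 1;
  }

  // A 2 x 3 float tensor; the backing store is owned by the caller until
  // TF_DeleteTensor invokes the deallocator above.
  long long dims[] = {2, 3};
  size_t nbytes = 6 * sizeof(float);
  float* values = (float*)malloc(nbytes);
  memset(values, 0, nbytes);
  TF_Tensor* tensor =
      TF_NewTensor(TF_FLOAT, dims, 2, values, nbytes, FreeBuffer, NULL);

  // ... TF_ExtendGraph / TF_Run calls would go here ...

  TF_DeleteTensor(tensor);
  TF_CloseSession(session, status);
  TF_DeleteSession(session, status);
  TF_DeleteStatus(status);
  return 0;
}
```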
diff --git a/tensorflow/core/public/tensor_shape.h b/tensorflow/core/public/tensor_shape.h
new file mode 100644
index 0000000000..a889b8b17d
--- /dev/null
+++ b/tensorflow/core/public/tensor_shape.h
@@ -0,0 +1,239 @@
+#ifndef TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_
+#define TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+class TensorShapeIter; // Declared below
+
+/// Manages the dimensions of a Tensor and their sizes.
+class TensorShape {
+ public:
+ /// \brief Construct a TensorShape from the provided sizes.
+ /// REQUIRES: dim_sizes[i] >= 0
+ explicit TensorShape(gtl::ArraySlice<int64> dim_sizes);
+ TensorShape(std::initializer_list<int64> dim_sizes)
+ : TensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
+
+ /// REQUIRES: IsValid(proto)
+ explicit TensorShape(const TensorShapeProto& proto);
+
+ /// Create a tensor shape with no dimensions and one element, which you can
+ /// then call AddDim() on.
+ TensorShape();
+
+ /// Returns true iff "proto" is a valid tensor shape.
+ static bool IsValid(const TensorShapeProto& proto);
+
+ /// Clear a tensor shape
+ void Clear();
+
+ /// \brief Add a dimension to the end ("inner-most").
+ /// REQUIRES: size >= 0
+ void AddDim(int64 size);
+
+ /// Appends all the dimensions from shape.
+ void AppendShape(const TensorShape& shape);
+
+ /// \brief Insert a dimension somewhere in the TensorShape.
+ /// REQUIRES: "0 <= d <= dims()"
+ /// REQUIRES: size >= 0
+ void InsertDim(int d, int64 size);
+
+ /// \brief Modifies the size of the dimension 'd' to be 'size'
+ /// REQUIRES: "0 <= d < dims()"
+ /// REQUIRES: size >= 0
+ void set_dim(int d, int64 size);
+
+ /// \brief Removes dimension 'd' from the TensorShape.
+ /// REQUIRES: "0 <= d < dims()"
+ void RemoveDim(int d);
+
+ /// Return the number of dimensions in the tensor.
+ int dims() const { return dim_sizes_.size(); }
+
+ /// \brief Returns the number of elements in dimension "d".
+ /// REQUIRES: "0 <= d < dims()"
+ // TODO(mdevin): Rename to dimension() to match Eigen::Tensor::dimension()?
+ int64 dim_size(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return dim_sizes_[d];
+ }
+
+ /// Returns sizes of all dimensions.
+ gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }
+
+ /// \brief Returns the number of elements in the tensor.
+ ///
+ /// We use int64 and not size_t to be compatible with Eigen::Tensor,
+ /// which uses ptrdiff_t.
+ int64 num_elements() const { return num_elements_; }
+
+ /// Returns true if *this and b have the same sizes. Ignores dimension names.
+ bool IsSameSize(const TensorShape& b) const;
+ bool operator==(const TensorShape& b) const { return IsSameSize(b); }
+
+ /// Fill *proto from *this.
+ void AsProto(TensorShapeProto* proto) const;
+
+ /// Fill *dsizes from *this.
+ template <int NDIMS>
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;
+
+ /// Same as AsEigenDSizes() but allows for NDIMS > dims() -- in which case we
+ /// pad the rest of the sizes with 1.
+ template <int NDIMS>
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;
+
+ /// For iterating through the dimensions.
+ TensorShapeIter begin() const;
+ TensorShapeIter end() const;
+
+ /// For error messages.
+ string DebugString() const;
+ // TODO(vrv): Remove this, this is the same as DebugString().
+ string ShortDebugString() const;
+
+ private:
+ /// Recalculates the dimensions of this tensor after they are modified.
+ void recompute_dims();
+
+ // TODO(josh11b): Maybe use something from the Eigen Tensor library
+ // for the sizes.
+ gtl::InlinedVector<int64, 4> dim_sizes_;
+
+ /// total number of elements (avoids recomputing it each time).
+ int64 num_elements_;
+};
+
+struct TensorShapeDim {
+ explicit TensorShapeDim(int64 s) : size(s) {}
+ int size;
+};
+
+class TensorShapeIter {
+ public:
+ TensorShapeIter(const TensorShape* shape, int d) : shape_(shape), d_(d) {}
+ bool operator==(const TensorShapeIter& rhs) {
+ DCHECK(shape_ == rhs.shape_);
+ return d_ == rhs.d_;
+ }
+ bool operator!=(const TensorShapeIter& rhs) {
+ DCHECK(shape_ == rhs.shape_);
+ return d_ != rhs.d_;
+ }
+ void operator++() { ++d_; }
+ TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
+
+ private:
+ const TensorShape* shape_;
+ int d_;
+};
+
+// In some places, allow shape (1,) to be treated as a scalar and shape () to be
+// treated as a vector. This flag is for temporary backwards compatibility
+// only, and will be changed to strict within Google around November 15, 2015.
+#if defined(PLATFORM_GOOGLE)
+// TODO(irving): Become strict on November 15, 2015.
+static const bool kAllowLegacyScalars = true;
+#else
+// For open source (outside Google), we are strict.
+static const bool kAllowLegacyScalars = false;
+#endif
+
+/// \brief Static helper routines for TensorShape. Includes a few common
+/// predicates on a tensor shape.
+class TensorShapeUtils {
+ public:
+ static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
+
+ static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
+
+ // Allow either scalars or (if allowing legacy scalars) shape (1,).
+ static bool IsLegacyScalar(const TensorShape& shape) {
+ return shape.dims() == 0 ||
+ (kAllowLegacyScalars && shape.dims() == 1 && shape.dim_size(0) == 1);
+ }
+
+ // Allow rank 1 or (if allowing legacy scalars) rank 0.
+ static bool IsLegacyVector(const TensorShape& shape) {
+ return shape.dims() == 1 || (kAllowLegacyScalars && shape.dims() == 0);
+ }
+
+ static bool IsVectorOrHigher(const TensorShape& shape) {
+ return shape.dims() >= 1;
+ }
+
+ static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
+
+ static bool IsMatrixOrHigher(const TensorShape& shape) {
+ return shape.dims() >= 2;
+ }
+
+ /// \brief Returns a TensorShape whose dimensions are dims[0], dims[1], ...,
+ /// dims[n-1].
+ template <typename T>
+ static TensorShape MakeShape(const T* dims, int n) {
+ TensorShape shape;
+ for (int i = 0; i < n; ++i) shape.AddDim(dims[i]);
+ return shape;
+ }
+
+ static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
+ string result = "[";
+ bool first = true;
+ for (const TensorShape& shape : shapes) {
+ strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
+ first = false;
+ }
+ strings::StrAppend(&result, "]");
+ return result;
+ }
+
+ static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1);
+};
+
+// TODO(josh11b): Add TensorStrides once we support strides
+// struct TensorStrides {
+// gtl::InlinedVector<int, 4> strides_;
+// };
+
+// ----------------------------------------------------------------------------
+// Template method implementation details below
+// ----------------------------------------------------------------------------
+
+template <int NDIMS>
+Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
+ CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
+ << " for a tensor of " << dims() << " dimensions";
+ return AsEigenDSizesWithPadding<NDIMS>();
+}
+
+template <int NDIMS>
+Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
+ const {
+ CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS
+ << " for a tensor of " << dims() << " dimensions";
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
+ for (int d = 0; d < dims(); d++) {
+ dsizes[d] = dim_size(d);
+ }
+ for (int d = dims(); d < NDIMS; d++) {
+ dsizes[d] = 1;
+ }
+ return dsizes;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_
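A small sketch of building and inspecting a TensorShape with the API above; the sizes are arbitrary:

```c++
#include "tensorflow/core/public/tensor_shape.h"

int main() {
  // Build {4, 3, 5} incrementally; num_elements() tracks the product.
  tensorflow::TensorShape shape;
  shape.AddDim(4);
  shape.AddDim(3);
  shape.AddDim(5);
  CHECK_EQ(shape.dims(), 3);
  CHECK_EQ(shape.num_elements(), 60);

  // Iterate over the dimensions via begin()/end().
  for (auto d : shape) {
    LOG(INFO) << "dim size: " << d.size;
  }

  // Predicates from TensorShapeUtils.
  CHECK(!tensorflow::TensorShapeUtils::IsMatrix(shape));
  CHECK(tensorflow::TensorShapeUtils::IsMatrixOrHigher(shape));

  // Fixed-rank Eigen sizes; NDIMS must equal dims() here.
  Eigen::DSizes<Eigen::DenseIndex, 3> dsizes = shape.AsEigenDSizes<3>();
  CHECK_EQ(dsizes[2], 5);
  return 0;
}
```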
diff --git a/tensorflow/core/public/tensorflow_server.h b/tensorflow/core/public/tensorflow_server.h
new file mode 100644
index 0000000000..0dac414555
--- /dev/null
+++ b/tensorflow/core/public/tensorflow_server.h
@@ -0,0 +1,19 @@
+#ifndef TENSORFLOW_PUBLIC_TENSORFLOW_SERVER_H_
+#define TENSORFLOW_PUBLIC_TENSORFLOW_SERVER_H_
+
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Initialize the TensorFlow service for this address space.
+// This is a blocking call that never returns.
+// See BUILD file for details on linkage guidelines.
+::tensorflow::Status InitTensorFlow();
+
+// Like InitTensorFlow() but returns after the TensorFlow
+// services have been launched.
+::tensorflow::Status LaunchTensorFlow();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PUBLIC_TENSORFLOW_SERVER_H_
diff --git a/tensorflow/core/user_ops/fact.cc b/tensorflow/core/user_ops/fact.cc
new file mode 100644
index 0000000000..7b6932244d
--- /dev/null
+++ b/tensorflow/core/user_ops/fact.cc
@@ -0,0 +1,29 @@
+// An example Op.
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("Fact")
+ .Output("fact: string")
+ .Doc(R"doc(
+Output a fact about factorials.
+)doc");
+
+class FactOp : public OpKernel {
+ public:
+ explicit FactOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Output a scalar string.
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape(), &output_tensor));
+ auto output = output_tensor->template scalar<string>();
+
+ output() = "0! == 1";
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("Fact").Device(DEVICE_CPU), FactOp);
diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc
new file mode 100644
index 0000000000..4e70b78751
--- /dev/null
+++ b/tensorflow/core/util/bcast.cc
@@ -0,0 +1,120 @@
+#include "tensorflow/core/util/bcast.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+/* static */
+void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); }
+
+BCast::BCast(const Vec& sx, const Vec& sy) {
+ // Reverse the shape of x and y for convenience.
+ // After the reverse, 0-th is the inner-most dimension.
+ Vec x = sx;
+ Reverse(&x);
+ Vec y = sy;
+ Reverse(&y);
+
+ // 1-extend and align x and y so that they are the same size.
+ if (x.size() > y.size()) {
+ y.resize(x.size(), 1);
+ } else {
+ x.resize(y.size(), 1);
+ }
+
+ // Going through each dimension starting from the inner-most
+ // dimension, compares dimension of x and y. They are compatible if
+ // they are equal or either is 1.
+ enum State {
+ UNKNOWN,
+ SAME,
+ X_ONE,
+ Y_ONE,
+ };
+ State prev = UNKNOWN;
+ const int64 n = x.size();
+ for (int i = 0; i < n; ++i) {
+ // Output shape.
+ State curr = UNKNOWN;
+ const int64 x_i = x[i]; // i-th dimension of x.
+ CHECK_GE(x_i, 0);
+ const int64 y_i = y[i]; // i-th dimension of y.
+ CHECK_GE(y_i, 0);
+ int64 o_i; // i-th dimension of the output.
+ int64 bx_i; // i-th broadcast for x.
+ int64 by_i; // i-th broadcast for y.
+ // Invariant:
+ // o_i = x_i * bx_i = y_i * by_i
+ if (x_i == y_i) {
+ // No broadcast.
+ o_i = x_i;
+ bx_i = 1;
+ by_i = 1;
+ curr = SAME;
+ } else if (x_i == 1) {
+ // x broadcast to y on this dimension.
+ o_i = y_i;
+ bx_i = y_i;
+ by_i = 1;
+ grad_x_reduce_idx_.push_back(n - 1 - i);
+ curr = X_ONE;
+ } else if (y_i == 1) {
+ // y broadcast to x on this dimension.
+ o_i = x_i;
+ bx_i = 1;
+ by_i = x_i;
+ grad_y_reduce_idx_.push_back(n - 1 - i);
+ curr = Y_ONE;
+ } else {
+ valid_ = false;
+ return;
+ }
+ output_.push_back(o_i);
+ // Reshape/broadcast.
+ // Invariant:
+ // result_[i] == x_reshape_[i] * x_bcast_[i] == y_reshape_[i] * y_bcast_[i]
+ if (curr == SAME && x_i == 1) {
+ // Both sides are 1s.
+ grad_x_reduce_idx_.push_back(n - 1 - i);
+ grad_y_reduce_idx_.push_back(n - 1 - i);
+ continue;
+ } else if (prev == curr) {
+ // It is a run of the same cases (no broadcast, x broadcast to
+ // y, y broadcast to x). We can reshape the input so that fewer
+ // dimensions are involved in the intermediate computation.
+ result_.back() *= o_i;
+ x_reshape_.back() *= x_i;
+ x_bcast_.back() *= bx_i;
+ y_reshape_.back() *= y_i;
+ y_bcast_.back() *= by_i;
+ } else {
+ result_.push_back(o_i);
+ x_reshape_.push_back(x_i);
+ x_bcast_.push_back(bx_i);
+ y_reshape_.push_back(y_i);
+ y_bcast_.push_back(by_i);
+ }
+ prev = curr;
+ }
+
+ if (result_.empty()) {
+ // Can happen when both x and y are effectively scalar.
+ result_.push_back(1);
+ x_reshape_.push_back(1);
+ x_bcast_.push_back(1);
+ y_reshape_.push_back(1);
+ y_bcast_.push_back(1);
+ }
+
+ // Reverse all vectors since x and y were reversed at very
+ // beginning.
+ Reverse(&x_reshape_);
+ Reverse(&x_bcast_);
+ Reverse(&y_reshape_);
+ Reverse(&y_bcast_);
+ Reverse(&result_);
+ Reverse(&output_);
+ Reverse(&grad_x_reduce_idx_);
+ Reverse(&grad_y_reduce_idx_);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
new file mode 100644
index 0000000000..9f0233e415
--- /dev/null
+++ b/tensorflow/core/util/bcast.h
@@ -0,0 +1,99 @@
+#ifndef TENSORFLOW_UTIL_BCAST_H_
+#define TENSORFLOW_UTIL_BCAST_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/core/platform/port.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+// BCast is a helper for broadcasting binary tensor operations.
+// TensorFlow's broadcasting rule follows that of numpy (See
+// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+//
+// The rule has the following properties:
+//
+// 1. suffix matching: the rule starts with the right-most
+// dimension, and works towards the left-most dimension. Since
+// TensorFlow is row-major, the right-most dimension (the last
+// element in the shape of a tensor) is the inner-most, a.k.a.
+// the fastest changing, dimension.
+//
+// 2. Two dimensions are compatible for broadcasting if both are the
+// same or either is 1.
+//
+// BCast takes the shape of two tensors and computes a few vectors of
+// int64 that are useful for the caller to reshape the tensors, apply
+// the right broadcasts to them, compute the broadcasted operation,
+// and possibly the gradients. In a nutshell, the caller is expected
+// to compute the broadcasted operation as follows:
+//
+// BCast b(x.shape(), y.shape());
+// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
+// _op_
+// y.reshape(b.y_reshape()).broadcast(b.y_bcast())
+//
+// For the gradient computation,
+// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
+// .reshape(x.shape())
+// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
+// .reshape(y.shape())
+// backprop_x and backprop_y are functionals of the binary function "op",
+// e.g.,
+// for +, backprop_x(x, y) = backprop_y(x, y) = 1;
+// for *, backprop_x(x, y) = y, backprop_y(x, y) = x;
+// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
+//
+// The multiplication in the grad * backprop_x itself is also
+// broadcasting following the same rule.
+//
+// TODO(zhifengc): Adds support for n-ary (n >= 2).
+class BCast {
+ public:
+ // A vector of int64 representing the shape of a tensor. The 0-th
+ // element is the outer-most dimension and the last element is the
+ // inner-most dimension. Note that we do not use TensorShape since
+ // it's more convenient to manipulate Vec directly for this module.
+ typedef std::vector<int64> Vec;
+
+ BCast(const Vec& x, const Vec& y);
+ ~BCast() {}
+
+ // Returns true iff two operands are compatible according to the
+ // broadcasting rule.
+ bool IsValid() const { return valid_; }
+
+ // If and only if IsValid(), the following fields can be used in
+ // implementing a broadcasted binary tensor operation according to
+ // the broadcasting rule.
+ const Vec& x_reshape() const { return x_reshape_; }
+ const Vec& x_bcast() const { return x_bcast_; }
+ const Vec& y_reshape() const { return y_reshape_; }
+ const Vec& y_bcast() const { return y_bcast_; }
+ const Vec& result_shape() const { return result_; }
+ const Vec& output_shape() const { return output_; }
+ const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
+ const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
+
+ private:
+ bool valid_ = true;
+ Vec x_reshape_;
+ Vec x_bcast_;
+ Vec y_reshape_;
+ Vec y_bcast_;
+ Vec result_;
+ Vec output_;
+ Vec grad_x_reduce_idx_;
+ Vec grad_y_reduce_idx_;
+
+ static void Reverse(Vec* shape);
+ static bool HasZero(const Vec& shape);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BCast);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_BCAST_H_
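A small driver sketch for the class above; the reshape/broadcast vectors quoted in the comments come from the Basic_Tensor_Matrix_Column case in bcast_test.cc below.

    #include <iostream>
    #include "tensorflow/core/util/bcast.h"

    int main() {
      using tensorflow::BCast;
      // Broadcast a [11, 7, 5, 3, 2] operand against a [3, 1] operand.
      BCast b({11, 7, 5, 3, 2}, {3, 1});
      if (!b.IsValid()) return 1;
      // Per the test below: x reshapes to [385, 3, 2] with bcast [1, 1, 1];
      // y reshapes to [1, 3, 1] with bcast [385, 1, 2]; output is [11,7,5,3,2].
      for (auto d : b.x_reshape()) std::cout << d << ' ';
      std::cout << '\n';
      for (auto d : b.y_bcast()) std::cout << d << ' ';
      std::cout << '\n';
      return 0;
    }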
diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc
new file mode 100644
index 0000000000..02d18586d6
--- /dev/null
+++ b/tensorflow/core/util/bcast_test.cc
@@ -0,0 +1,226 @@
+#include "tensorflow/core/util/bcast.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y) {
+ tensorflow::BCast b(x, y);
+ if (!b.IsValid()) {
+ return "invalid";
+ }
+ string ret;
+ strings::StrAppend(&ret, "[", str_util::Join(b.x_reshape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.x_bcast(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.y_reshape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.y_bcast(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.result_shape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.output_shape(), ","), "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.grad_x_reduce_idx(), ","),
+ "]");
+ strings::StrAppend(&ret, "[", str_util::Join(b.grad_y_reduce_idx(), ","),
+ "]");
+ return ret;
+}
+
+TEST(BCastTest, Invalid) {
+ EXPECT_EQ("invalid", BCast({5, 3, 2}, {3}));
+ EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2}));
+ EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1}));
+ EXPECT_EQ("invalid", BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}));
+}
+
+TEST(BCastTest, Basic_SameShape) {
+ // Effectively no broadcast needed.
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}),
+ "[2310][1][2310][1]"
+ "[2310]"
+ "[11,7,5,3,2]"
+ "[][]");
+}
+
+TEST(BCastTest, Basic_Scalar_Scalar) {
+ // Effectively it's a scalar and a scalar.
+ // [1, 1] [1]
+ EXPECT_EQ(BCast({1, 1}, {1}),
+ "[1][1][1][1]"
+ "[1]"
+ "[1,1]"
+ "[0,1][0,1]");
+
+ // [1] [1, 1]
+ EXPECT_EQ(BCast({1}, {1, 1}),
+ "[1][1][1][1]"
+ "[1]"
+ "[1,1]"
+ "[0,1][0,1]");
+}
+
+TEST(BCastTest, Basic_Tensor_Scalar) {
+ // Effectively it's a tensor and a scalar.
+ // [11, 7, 5, 3, 2] [1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}),
+ "[2310][1][1][2310]"
+ "[2310]"
+ "[11,7,5,3,2]"
+ "[][0,1,2,3,4]");
+
+ // [1] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}),
+ "[1][2310][2310][1]"
+ "[2310]"
+ "[11,7,5,3,2]"
+ "[0,1,2,3,4][]");
+}
+
+TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) {
+ // Effectively it's a tensor and a scalar.
+ // [11, 7, 5, 3, 2, 1] [1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}),
+ "[2310][1][1][2310]"
+ "[2310]"
+ "[11,7,5,3,2,1]"
+ "[5][0,1,2,3,4,5]");
+
+ // [1] [11, 7, 5, 3, 2, 1]
+ EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}),
+ "[1][2310][2310][1]"
+ "[2310]"
+ "[11,7,5,3,2,1]"
+ "[0,1,2,3,4,5][5]");
+
+ // Effectively it's a tensor and a scalar.
+ // [11, 7, 5, 1, 1, 3, 2, 1, 1] [1]
+ EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}),
+ "[2310][1][1][2310]"
+ "[2310]"
+ "[11,7,5,1,1,3,2,1,1]"
+ "[3,4,7,8][0,1,2,3,4,5,6,7,8]");
+
+ // [1] [11, 7, 5, 1, 1, 3, 2, 1, 1]
+ EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}),
+ "[1][2310][2310][1]"
+ "[2310]"
+ "[11,7,5,1,1,3,2,1,1]"
+ "[0,1,2,3,4,5,6,7,8][3,4,7,8]");
+}
+
+TEST(BCastTest, Basic_Tensor_Vector) {
+ // [11, 7, 5, 3, 2] [2]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}),
+ "[1155,2][1,1][1,2][1155,1]"
+ "[1155,2]"
+ "[11,7,5,3,2]"
+ "[][0,1,2,3]");
+
+ // [2] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}),
+ "[1,2][1155,1][1155,2][1,1]"
+ "[1155,2]"
+ "[11,7,5,3,2]"
+ "[0,1,2,3][]");
+}
+
+TEST(BCastTest, Basic_Tensor_Matrix) {
+ // [11, 7, 5, 3, 2] [3, 2]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}),
+ "[385,6][1,1][1,6][385,1]"
+ "[385,6]"
+ "[11,7,5,3,2]"
+ "[][0,1,2]");
+ // [3, 2] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}),
+ "[1,6][385,1][385,6][1,1]"
+ "[385,6]"
+ "[11,7,5,3,2]"
+ "[0,1,2][]");
+}
+
+TEST(BCastTest, Basic_Tensor_Matrix_Column) {
+ // [11, 7, 5, 3, 2] [3, 1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}),
+ "[385,3,2][1,1,1][1,3,1][385,1,2]"
+ "[385,3,2]"
+ "[11,7,5,3,2]"
+ "[][0,1,2,4]");
+
+ // [3, 1] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}),
+ "[1,3,1][385,1,2][385,3,2][1,1,1]"
+ "[385,3,2]"
+ "[11,7,5,3,2]"
+ "[0,1,2,4][]");
+}
+
+TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) {
+ // [11, 7, 5, 3, 2] [7, 5, 1, 1]
+ EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}),
+ "[11,35,6][1,1,1][1,35,1][11,1,6]"
+ "[11,35,6]"
+ "[11,7,5,3,2]"
+ "[][0,3,4]");
+
+ // [7, 5, 1, 1] [11, 7, 5, 3, 2]
+ EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}),
+ "[1,35,1][11,1,6][11,35,6][1,1,1]"
+ "[11,35,6]"
+ "[11,7,5,3,2]"
+ "[0,3,4][]");
+}
+
+TEST(BCastTest, Complex_BCast_To_Each_Other) {
+ // Rare cases. x and y broadcast to each other. x and y are of
+ // different ranks.
+ // Can be verified in numpy as:
+ // import numpy as np
+ // x = np.arange(0,110).reshape([11,1,5,1,2])
+ // y = np.arange(0,21).reshape([7,1,3,1])
+ // np.shape(x + y)
+ // Out[.]: (11, 7, 5, 3, 2)
+ EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}),
+ "[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]"
+ "[11,7,5,3,2]"
+ "[11,7,5,3,2]"
+ "[1,3][0,2,4]");
+}
+
+TEST(BCastTest, TestZeroDimensionShape) {
+ EXPECT_EQ(BCast({2, 0, 5}, {5}),
+ "[0,5][1,1][1,5][0,1]"
+ "[0,5]"
+ "[2,0,5]"
+ "[][0,1]");
+ EXPECT_EQ(BCast({5}, {2, 0, 5}),
+ "[1,5][0,1][0,5][1,1]"
+ "[0,5]"
+ "[2,0,5]"
+ "[0,1][]");
+
+ EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}),
+ "[0,5][1,1][1,5][0,1]"
+ "[0,5]"
+ "[2,0,3,0,5]"
+ "[][0,1,2,3]");
+ EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}),
+ "[1,5][0,1][0,5][1,1]"
+ "[0,5]"
+ "[2,0,3,0,5]"
+ "[0,1,2,3][]");
+
+ EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}),
+ "[0,3,0,5][1,1,1,1][1,3,1,5][0,1,0,1]"
+ "[0,3,0,5]"
+ "[2,0,3,0,5]"
+ "[][0,1,3]");
+ EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}),
+ "[1,3,1,5][0,1,0,1][0,3,0,5][1,1,1,1]"
+ "[0,3,0,5]"
+ "[2,0,3,0,5]"
+ "[0,1,3][]");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
new file mode 100644
index 0000000000..b8c6a77dd0
--- /dev/null
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -0,0 +1,338 @@
+#include "tensorflow/core/util/device_name_utils.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+static bool IsAlpha(char c) {
+ return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
+}
+
+static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
+
+// Returns true iff "in" is a valid job name.
+static bool IsJobName(StringPiece in) {
+ if (in.empty()) return false;
+ if (!IsAlpha(in[0])) return false;
+ for (size_t i = 1; i < in.size(); ++i) {
+ if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false;
+ }
+ return true;
+}
+
+// Returns true and fills in "*job" iff "*in" starts with a job name.
+static bool ConsumeJobName(StringPiece* in, string* job) {
+ if (in->empty()) return false;
+ if (!IsAlpha((*in)[0])) return false;
+ size_t i = 1;
+ for (; i < in->size(); ++i) {
+ const char c = (*in)[i];
+ if (c == '/') break;
+ if (!(IsAlphaNum(c) || c == '_')) {
+ return false;
+ }
+ }
+ job->assign(in->data(), i);
+ in->remove_prefix(i);
+ return true;
+}
+
+// Returns true and fills in "*device_type" iff "*in" starts with a device type
+// name.
+static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
+ if (in->empty()) return false;
+ if (!IsAlpha((*in)[0])) return false;
+ size_t i = 1;
+ for (; i < in->size(); ++i) {
+ const char c = (*in)[i];
+ if (c == '/' || c == ':') break;
+ if (!(IsAlphaNum(c) || c == '_')) {
+ return false;
+ }
+ }
+ device_type->assign(in->data(), i);
+ in->remove_prefix(i);
+ return true;
+}
+
+// Returns true and fills in "*val" iff "*in" starts with a decimal
+// number.
+static bool ConsumeNumber(StringPiece* in, int* val) {
+ uint64 tmp;
+ if (str_util::ConsumeLeadingDigits(in, &tmp)) {
+ *val = tmp;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+/* static */
+string DeviceNameUtils::FullName(const string& job, int replica, int task,
+ const string& type, int id) {
+ CHECK(IsJobName(job)) << job;
+ CHECK_LE(0, replica);
+ CHECK_LE(0, task);
+ CHECK(!type.empty());
+ CHECK_LE(0, id);
+ return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
+ "/device:", type, ":", id);
+}
+
+bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
+ p->Clear();
+ if (fullname == "/") {
+ return true;
+ }
+ StringPiece tmp;
+ while (!fullname.empty()) {
+ if (str_util::ConsumePrefix(&fullname, "/job:")) {
+ p->has_job = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/replica:")) {
+ p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/task:")) {
+ p->has_task = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/device:")) {
+ p->has_type = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
+ return false;
+ }
+ if (!str_util::ConsumePrefix(&fullname, ":")) {
+ p->has_id = false;
+ } else {
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
+ return false;
+ }
+ }
+
+ } else if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
+ str_util::ConsumePrefix(&fullname, "/CPU:")) {
+ p->has_type = true;
+ p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
+ return false;
+ }
+ } else if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
+ str_util::ConsumePrefix(&fullname, "/GPU:")) {
+ p->has_type = true;
+ p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
+ p->has_id = !str_util::ConsumePrefix(&fullname, "*");
+ if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
+/* static */
+string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
+ string buf;
+ if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
+ if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
+ if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
+ if (pn.has_type) {
+ strings::StrAppend(&buf, "/", pn.type, ":");
+ if (pn.has_id) {
+ strings::StrAppend(&buf, pn.id);
+ } else {
+ strings::StrAppend(&buf, "*");
+ }
+ }
+ return buf;
+}
+
+/* static */
+bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
+ const ParsedName& more_specific) {
+ if (less_specific.has_job &&
+ (!more_specific.has_job || (less_specific.job != more_specific.job))) {
+ return false;
+ }
+ if (less_specific.has_replica &&
+ (!more_specific.has_replica ||
+ (less_specific.replica != more_specific.replica))) {
+ return false;
+ }
+ if (less_specific.has_task &&
+ (!more_specific.has_task || (less_specific.task != more_specific.task))) {
+ return false;
+ }
+ if (less_specific.has_type &&
+ (!more_specific.has_type || (less_specific.type != more_specific.type))) {
+ return false;
+ }
+ if (less_specific.has_id &&
+ (!more_specific.has_id || (less_specific.id != more_specific.id))) {
+ return false;
+ }
+ return true;
+}
+
+/* static */
+bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
+ const ParsedName& name) {
+ CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
+ name.has_id);
+
+ if (pattern.has_job && (pattern.job != name.job)) return false;
+ if (pattern.has_replica && (pattern.replica != name.replica)) return false;
+ if (pattern.has_task && (pattern.task != name.task)) return false;
+ if (pattern.has_type && (pattern.type != name.type)) return false;
+ if (pattern.has_id && (pattern.id != name.id)) return false;
+ return true;
+}
+
+/* static */
+Status DeviceNameUtils::MergeDevNames(ParsedName* target,
+ const ParsedName& other,
+ bool allow_soft_placement) {
+ if (other.has_job) {
+ if (target->has_job && target->job != other.job) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible jobs: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_job = other.has_job;
+ target->job = other.job;
+ }
+ }
+
+ if (other.has_replica) {
+ if (target->has_replica && target->replica != other.replica) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible replicas: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_replica = other.has_replica;
+ target->replica = other.replica;
+ }
+ }
+
+ if (other.has_task) {
+ if (target->has_task && target->task != other.task) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible tasks: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_task = other.has_task;
+ target->task = other.task;
+ }
+ }
+
+ if (other.has_type) {
+ if (target->has_type && target->type != other.type) {
+ if (!allow_soft_placement) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible types: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_id = false;
+ target->has_type = false;
+ return Status::OK();
+ }
+ } else {
+ target->has_type = other.has_type;
+ target->type = other.type;
+ }
+ }
+
+ if (other.has_id) {
+ if (target->has_id && target->id != other.id) {
+ if (!allow_soft_placement) {
+ return errors::InvalidArgument(
+ "Cannot merge devices with incompatible ids: '",
+ ParsedNameToString(*target), "' and '", ParsedNameToString(other),
+ "'");
+ } else {
+ target->has_id = false;
+ return Status::OK();
+ }
+ } else {
+ target->has_id = other.has_id;
+ target->id = other.id;
+ }
+ }
+
+ return Status::OK();
+}
+
+/* static */
+bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
+ const ParsedName& b) {
+ return (a.has_job && b.has_job && (a.job == b.job)) &&
+ (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
+ (a.has_task && b.has_task && (a.task == b.task));
+}
+
+/* static */
+bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
+ ParsedName x;
+ ParsedName y;
+ return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
+ IsSameAddressSpace(x, y);
+}
+
+/* static */
+string DeviceNameUtils::LocalName(StringPiece type, int id) {
+ return strings::StrCat(type, ":", id);
+}
+
+/* static */
+string DeviceNameUtils::LocalName(StringPiece fullname) {
+ ParsedName x;
+ CHECK(ParseFullName(fullname, &x)) << fullname;
+ return LocalName(x.type, x.id);
+}
+
+/* static */
+bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
+ ParsedName x;
+ if (!ConsumeDeviceType(&name, &p->type)) {
+ return false;
+ }
+ if (!str_util::ConsumePrefix(&name, ":")) {
+ return false;
+ }
+ if (!ConsumeNumber(&name, &p->id)) {
+ return false;
+ }
+ return name.empty();
+}
+
+/* static */
+bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
+ string* device) {
+ ParsedName pn;
+ if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
+ *task = strings::StrCat(
+ (pn.has_job ? strings::StrCat("/job:", pn.job) : ""),
+ (pn.has_replica ? strings::StrCat("/replica:", pn.replica) : ""),
+ (pn.has_task ? strings::StrCat("/task:", pn.task) : ""));
+ *device = strings::StrCat(pn.type, ":", pn.id);
+ return true;
+ }
+ return false;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
new file mode 100644
index 0000000000..8b0a24ed0d
--- /dev/null
+++ b/tensorflow/core/util/device_name_utils.h
@@ -0,0 +1,141 @@
+#ifndef TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+#define TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// In TensorFlow a device name is a string of the following form:
+// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
+//
+// <name> is a short identifier conforming to the regexp
+// [a-zA-Z][_a-zA-Z0-9]*
+// <type> is a supported device type (e.g. 'cpu' or 'gpu')
+// <replica>, <task>, <device_num> are small non-negative integers and are
+// densely allocated (except in tests).
+//
+// For some purposes, we also allow device patterns, which can specify
+// some or none of the specific fields above, with missing components,
+// or "<component>:*" indicating "any value allowed for that component.
+//
+// For example:
+// "/job:param_server" - Consider any devices in the "param_server" job
+// "/device:cpu:*" - Consider any cpu devices in any job/task/replica
+// "/job:*/replica:*/task:*/device:cpu:*" - Consider any cpu devices in any
+// job/task/replica
+// "/job:w/replica:0/task:0/device:gpu:*" - Consider any gpu devices in
+// replica 0, task 0, of job "w"
+class DeviceNameUtils {
+ public:
+ // Returns a fully qualified device name given the parameters.
+ static string FullName(const string& job, int replica, int task,
+ const string& type, int id);
+
+ struct ParsedName {
+ void Clear() {
+ has_job = false;
+ has_replica = false;
+ has_task = false;
+ has_type = false;
+ has_id = false;
+ job.clear();
+ replica = 0;
+ task = 0;
+ type.clear();
+ id = 0;
+ }
+
+ bool operator==(const ParsedName& other) const {
+ return (has_job ? (other.has_job && job == other.job) : !other.has_job) &&
+ (has_replica ? (other.has_replica && replica == other.replica)
+ : !other.has_replica) &&
+ (has_task ? (other.has_task && task == other.task)
+ : !other.has_task) &&
+ (has_type ? (other.has_type && type == other.type)
+ : !other.has_type) &&
+ (has_id ? (other.has_id && id == other.id) : !other.has_id);
+ }
+
+ bool has_job = false;
+ string job;
+ bool has_replica = false;
+ int replica = 0;
+ bool has_task = false;
+ int task = 0;
+ bool has_type = false;
+ string type;
+ bool has_id = false;
+ int id = 0;
+ };
+ // Parses "fullname" into "*parsed". Returns true iff succeeds.
+ static bool ParseFullName(StringPiece fullname, ParsedName* parsed);
+
+ // Returns true if "name" specifies any non-trivial constraint on the device.
+ static bool HasSomeDetails(const ParsedName& name) {
+ return name.has_job || name.has_replica || name.has_task || name.has_type ||
+ name.has_id;
+ }
+
+ // Returns true if more_specific is a specification of
+ // less_specific, i.e. everywhere that less-specific has a
+ // non-wildcard component value, more_specific has the same value
+ // for that component.
+ static bool IsSpecification(const ParsedName& less_specific,
+ const ParsedName& more_specific);
+
+ // Like IsSpecification, but the second argument "name" must have a
+ // non-wildcard value for all of its components.
+ static bool IsCompleteSpecification(const ParsedName& pattern,
+ const ParsedName& name);
+
+ // True iff there exists any possible complete device name that is
+ // a specification of both "a" and "b".
+ static inline bool AreCompatibleDevNames(const ParsedName& a,
+ const ParsedName& b) {
+ return IsSpecification(a, b) || IsSpecification(b, a);
+ }
+
+ // Merges the device specifications in "*target" and "other", and
+ // stores the result in "*target". Returns OK if "*target" and
+ // "other" are compatible, otherwise returns an error.
+ static Status MergeDevNames(ParsedName* target, const ParsedName& other) {
+ return MergeDevNames(target, other, false);
+ }
+ static Status MergeDevNames(ParsedName* target, const ParsedName& other,
+ bool allow_soft_placement);
+
+ // Returns true iff devices identified by 'src' and 'dst' are in the
+ // same address space.
+ static bool IsSameAddressSpace(StringPiece src, StringPiece dst);
+ static bool IsSameAddressSpace(const ParsedName& src, const ParsedName& dst);
+
+ // Returns the local device given its "type" and "id".
+ static string LocalName(StringPiece type, int id);
+
+ // Returns a short local device name (cpu:0, gpu:1, etc) based on
+ // the given fullname.
+ static string LocalName(StringPiece fullname);
+
+ // If "name" is a valid local device name (cpu:0, gpu:1, etc.),
+ // fills in parsed.type and parsed.id accordingly. Returns true iff
+ // succeeds.
+ static bool ParseLocalName(StringPiece name, ParsedName* parsed);
+
+ // Splits a fully-qualified device name into a task identifier and a
+ // relative device identifier. It first parses "name" using
+ // ParseFullName(), then assigns *task with everything except for
+ // the local device component, and assigns the relative device
+ // component into *device. This function will still return true if
+ // the task component is empty, but it requires the relative device
+ // component to be fully specified.
+ static bool SplitDeviceName(StringPiece name, string* task, string* device);
+
+ static string ParsedNameToString(const ParsedName& pn);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
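A hedged sketch of how the helpers above compose, using only functions declared in this header; the parsing, merging, and local-name behavior follows what device_name_utils_test.cc below exercises.

    #include <iostream>
    #include "tensorflow/core/util/device_name_utils.h"

    int main() {
      using tensorflow::DeviceNameUtils;
      DeviceNameUtils::ParsedName n;
      if (DeviceNameUtils::ParseFullName(
              "/job:worker/replica:1/task:2/device:GPU:3", &n)) {
        // Prints the canonical form of the parsed name.
        std::cout << DeviceNameUtils::ParsedNameToString(n) << "\n";
      }
      // Merge a partial spec into another; fields missing on one side are
      // filled in from the other, and conflicts yield an error status.
      DeviceNameUtils::ParsedName target;
      DeviceNameUtils::ParseFullName("/job:worker", &target);
      tensorflow::Status s = DeviceNameUtils::MergeDevNames(&target, n);
      if (s.ok()) {
        std::cout << DeviceNameUtils::ParsedNameToString(target) << "\n";
      }
      std::cout << DeviceNameUtils::LocalName("GPU", 3) << "\n";  // "GPU:3"
      return 0;
    }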
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
new file mode 100644
index 0000000000..14f30d6de5
--- /dev/null
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -0,0 +1,369 @@
+#include "tensorflow/core/util/device_name_utils.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+TEST(DeviceNameUtilsTest, Basic) {
+ EXPECT_EQ(DeviceNameUtils::FullName("hello", 1, 2, "CPU", 3),
+ "/job:hello/replica:1/task:2/device:CPU:3");
+
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_FALSE(DeviceNameUtils::ParseFullName("foobar", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:3", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:123/replica:1/task:2/gpu:", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseFullName(
+ "/job:123/replica:1/task:2/device:gpu:", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:-1/task:2/gpu:3", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:-2/gpu:3", &p));
+ EXPECT_FALSE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/bar:3", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseFullName(
+ "/job:foo/replica:1/task:2/gpu:3/extra", &p));
+ EXPECT_TRUE(
+ DeviceNameUtils::ParseFullName("/job:foo/replica:1/task:2/gpu:3", &p));
+ EXPECT_TRUE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_TRUE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.job, "foo");
+ EXPECT_EQ(p.replica, 1);
+ EXPECT_EQ(p.task, 2);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 3);
+ }
+ {
+ // Allow _ in job names.
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(
+ "/job:foo_bar/replica:1/task:2/gpu:3", &p));
+ EXPECT_TRUE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_TRUE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.job, "foo_bar");
+ EXPECT_EQ(p.replica, 1);
+ EXPECT_EQ(p.task, 2);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 3);
+ }
+ {
+ // Allow _ in job names.
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(
+ "/job:foo_bar/replica:1/task:2/device:GPU:3", &p));
+ EXPECT_TRUE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_TRUE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.job, "foo_bar");
+ EXPECT_EQ(p.replica, 1);
+ EXPECT_EQ(p.task, 2);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 3);
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:*", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(
+ DeviceNameUtils::ParseFullName("/job:*/replica:4/device:GPU:*", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(
+ DeviceNameUtils::ParseFullName("/job:*/device:GPU/replica:4", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(
+ "/job:*/replica:4/device:myspecialdevice:13", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "myspecialdevice");
+ EXPECT_EQ(p.id, 13);
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_FALSE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_FALSE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ }
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/job:*/replica:4/gpu:5", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_TRUE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ EXPECT_EQ(p.id, 5);
+ }
+ { // Same result if we reorder the components
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName("/gpu:*/job:*/replica:4", &p));
+ EXPECT_FALSE(p.has_job);
+ EXPECT_TRUE(p.has_replica);
+ EXPECT_FALSE(p.has_task);
+ EXPECT_TRUE(p.has_type);
+ EXPECT_FALSE(p.has_id);
+ EXPECT_EQ(p.replica, 4);
+ EXPECT_EQ(p.type, "GPU");
+ }
+
+ EXPECT_TRUE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:2/gpu:4"));
+ EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:1/task:3/gpu:4"));
+ EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:foo/replica:10/task:2/gpu:4"));
+ EXPECT_FALSE(DeviceNameUtils::IsSameAddressSpace(
+ "/job:foo/replica:1/task:2/cpu:3", "/job:bar/replica:1/task:2/gpu:4"));
+
+ EXPECT_EQ(DeviceNameUtils::LocalName("CPU", 1), "CPU:1");
+ EXPECT_EQ(DeviceNameUtils::LocalName("GPU", 2), "GPU:2");
+ EXPECT_EQ(DeviceNameUtils::LocalName("MySpecialDevice", 13),
+ "MySpecialDevice:13");
+
+ EXPECT_EQ(
+ DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:CPU:3"),
+ "CPU:3");
+
+ EXPECT_EQ(DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/cpu:3"),
+ "CPU:3");
+
+ EXPECT_EQ(
+ DeviceNameUtils::LocalName("/job:foo/replica:1/task:2/device:abc:73"),
+ "abc:73");
+
+ {
+ DeviceNameUtils::ParsedName p;
+ EXPECT_TRUE(DeviceNameUtils::ParseLocalName("CPU:10", &p));
+ EXPECT_EQ(p.type, "CPU");
+ EXPECT_EQ(p.id, 10);
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("cpu:abc", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc:", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("abc", &p));
+ EXPECT_FALSE(DeviceNameUtils::ParseLocalName("myspecialdevice", &p));
+ }
+}
+
+static bool IsCSHelper(StringPiece pattern, StringPiece actual) {
+ DeviceNameUtils::ParsedName p, a;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p));
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a));
+ return DeviceNameUtils::IsCompleteSpecification(p, a);
+}
+
+TEST(DeviceNameUtilsTest, IsCompleteSpecification) {
+ EXPECT_TRUE(IsCSHelper("/job:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(
+ IsCSHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsCSHelper("/job:*/task:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsCSHelper("/job:*/replica:*/task:*",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(
+ IsCSHelper("/job:*/replica:*/gpu:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsCSHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsCSHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1"));
+ EXPECT_TRUE(IsCSHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+}
+
+static bool IsSpecHelper(StringPiece pattern, StringPiece actual) {
+ DeviceNameUtils::ParsedName p, a;
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(pattern, &p));
+ EXPECT_TRUE(DeviceNameUtils::ParseFullName(actual, &a));
+ return DeviceNameUtils::IsSpecification(p, a);
+}
+
+TEST(DeviceNameUtilsTest, IsSpecification) {
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work/replica:1"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/replica:1"));
+ EXPECT_TRUE(IsSpecHelper("/job:*", "/job:work"));
+ EXPECT_TRUE(
+ IsSpecHelper("/job:*/replica:*", "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:*",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/gpu:3",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:1/task:2",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/job:work/replica:*/task:2",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/task:*", "/job:*/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/task:2", "/job:*/replica:1/task:2/gpu:3"));
+ EXPECT_TRUE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/cpu:1"));
+ EXPECT_TRUE(IsSpecHelper("/cpu:0", "/cpu:0"));
+ EXPECT_TRUE(IsSpecHelper("/gpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+
+ EXPECT_FALSE(IsSpecHelper("/job:worker/replica:1/task:2/gpu:3", "/gpu:*"));
+ EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2"));
+ EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:*/replica:1/task:2/gpu:1"));
+ EXPECT_FALSE(IsSpecHelper("/cpu:*", "/job:worker/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsSpecHelper("/gpu:2", "/job:worker/replica:1/task:2/gpu:1"));
+ EXPECT_FALSE(IsSpecHelper("/job:work/replica:*/task:0",
+ "/job:work/replica:1/task:2/gpu:3"));
+ EXPECT_FALSE(IsSpecHelper("/job:work/replica:0/task:2",
+ "/job:work/replica:*/task:2/gpu:3"));
+}
+
+TEST(DeviceNameUtilsTest, SplitDeviceName) {
+ string task;
+ string device;
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName(
+ "/job:foo/replica:1/task:2/cpu:1", &task, &device));
+ EXPECT_EQ("/job:foo/replica:1/task:2", task);
+ EXPECT_EQ("CPU:1", device);
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName(
+ "/job:foo/cpu:1/task:2/replica:1", &task, &device));
+ EXPECT_EQ("/job:foo/replica:1/task:2", task);
+ EXPECT_EQ("CPU:1", device);
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/gpu:3", &task, &device));
+ EXPECT_EQ("", task);
+ EXPECT_EQ("GPU:3", device);
+ EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("gpu:3", &task, &device));
+ EXPECT_FALSE(DeviceNameUtils::SplitDeviceName("/job:foo/task:2/replica:1",
+ &task, &device));
+ EXPECT_TRUE(DeviceNameUtils::SplitDeviceName("/device:myspecialdevice:3",
+ &task, &device));
+ EXPECT_EQ("", task);
+ EXPECT_EQ("myspecialdevice:3", device);
+}
+
+static DeviceNameUtils::ParsedName Name(const string& str) {
+ DeviceNameUtils::ParsedName ret;
+ CHECK(DeviceNameUtils::ParseFullName(str, &ret)) << "Invalid name: " << str;
+ return ret;
+}
+
+static void MergeDevNamesHelperImpl(const string& name_a, const string& name_b,
+ const string& expected_merge_name,
+ bool allow_soft_placement) {
+ DeviceNameUtils::ParsedName target_a = Name(name_a);
+ EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_a, Name(name_b),
+ allow_soft_placement));
+ DeviceNameUtils::ParsedName target_b = Name(name_b);
+ EXPECT_OK(DeviceNameUtils::MergeDevNames(&target_b, Name(name_a),
+ allow_soft_placement));
+ EXPECT_EQ(target_a, target_b);
+ EXPECT_EQ(target_a, Name(expected_merge_name));
+ EXPECT_EQ(target_b, Name(expected_merge_name));
+}
+
+static void MergeDevNamesHelper(const string& name_a, const string& name_b,
+ const string& expected_merge_name) {
+ MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, false);
+}
+
+static void MergeDevNamesHelperAllowSoftPlacement(
+ const string& name_a, const string& name_b,
+ const string& expected_merge_name) {
+ MergeDevNamesHelperImpl(name_a, name_b, expected_merge_name, true);
+}
+
+static void MergeDevNamesError(const string& name_a, const string& name_b,
+ const string& expected_error_substr) {
+ DeviceNameUtils::ParsedName target_a = Name(name_a);
+ Status s = DeviceNameUtils::MergeDevNames(&target_a, Name(name_b));
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_TRUE(StringPiece(s.error_message()).contains(expected_error_substr))
+ << s;
+}
+
+TEST(DeviceNameUtilsTest, MergeDevNames) {
+ DeviceNameUtils::ParsedName target;
+
+ // Idempotence tests.
+ MergeDevNamesHelper("", "", "");
+ MergeDevNamesHelper("/job:foo/replica:1/task:2/cpu:1",
+ "/job:foo/replica:1/task:2/cpu:1",
+ "/job:foo/replica:1/task:2/cpu:1");
+
+ // Merging with empty device has no effect.
+ MergeDevNamesHelper("", "/job:foo", "/job:foo");
+ MergeDevNamesHelper("", "/replica:2", "/replica:2");
+ MergeDevNamesHelper("", "/task:7", "/task:7");
+ // MergeDevNamesHelper("", "/gpu:1", "/gpu:1");
+
+ // Combining disjoint names.
+ MergeDevNamesHelper("/job:foo", "/task:7", "/job:foo/task:7");
+ MergeDevNamesHelper("/job:foo", "/gpu:1", "/job:foo/gpu:1");
+
+ // Combining overlapping names.
+ MergeDevNamesHelper("/job:foo/replica:0", "/replica:0/task:1",
+ "/job:foo/replica:0/task:1");
+
+ // Wildcard tests.
+ MergeDevNamesHelper("", "/gpu:*", "/gpu:*");
+ MergeDevNamesHelper("/gpu:*", "/gpu:*", "/gpu:*");
+ MergeDevNamesHelper("/gpu:1", "/gpu:*", "/gpu:1");
+
+ // Incompatible components.
+ MergeDevNamesError("/job:foo", "/job:bar", "incompatible jobs");
+ MergeDevNamesError("/replica:0", "/replica:1", "incompatible replicas");
+ MergeDevNamesError("/task:0", "/task:1", "incompatible tasks");
+ MergeDevNamesError("/gpu:*", "/cpu:*", "incompatible types");
+ MergeDevNamesError("/gpu:0", "/gpu:1", "incompatible ids");
+}
+
+TEST(DeviceNameUtilsTest, MergeDevNamesAllowSoftPlacement) {
+ // Incompatible components with allow_soft_placement.
+ MergeDevNamesHelperAllowSoftPlacement("/gpu:*", "/cpu:1", "");
+ MergeDevNamesHelperAllowSoftPlacement("/cpu:*", "/gpu:1", "");
+ MergeDevNamesHelperAllowSoftPlacement("/gpu:1", "/gpu:2", "/gpu:*");
+}
+
+static void BM_ParseFullName(int iters) {
+ DeviceNameUtils::ParsedName p;
+ while (iters--) {
+ DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
+ }
+}
+BENCHMARK(BM_ParseFullName);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto
new file mode 100644
index 0000000000..5d67823ce7
--- /dev/null
+++ b/tensorflow/core/util/event.proto
@@ -0,0 +1,29 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/summary.proto";
+
+// Protocol buffer representing an event that happened during
+// the execution of a Brain model.
+message Event {
+ // Timestamp of the event.
+ double wall_time = 1;
+
+ // Global step of the event.
+ int64 step = 2;
+
+ oneof what {
+ // An event file was started, with the specified version.
+ // This is used to identify the contents of the record IO files
+ // easily. Current version is "tensorflow.Event:1". All versions
+ // start with "tensorflow.Event:".
+ string file_version = 3;
+ // A model was constructed.
+ GraphDef graph_def = 4;
+ // A summary was generated.
+ Summary summary = 5;
+ }
+}
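For orientation, a hedged C++ sketch of populating the message above the way events_writer_test.cc does further down; note that file_version, graph_def, and summary share the oneof, so setting one clears the others.

    #include <string>

    #include "tensorflow/core/platform/port.h"
    #include "tensorflow/core/util/event.pb.h"

    tensorflow::Event MakeScalarEvent(double wall_time, tensorflow::int64 step,
                                      const std::string& tag, float value) {
      tensorflow::Event event;
      event.set_wall_time(wall_time);
      event.set_step(step);
      // Adding a summary value selects "summary" within the oneof.
      tensorflow::Summary::Value* v = event.mutable_summary()->add_value();
      v->set_tag(tag);
      v->set_simple_value(value);
      return event;
    }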
diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc
new file mode 100644
index 0000000000..1b34a36577
--- /dev/null
+++ b/tensorflow/core/util/events_writer.cc
@@ -0,0 +1,144 @@
+#include "tensorflow/core/util/events_writer.h"
+
+#include <stddef.h> // for NULL
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+
+EventsWriter::EventsWriter(const string& file_prefix)
+ // TODO(jeff,sanjay): Pass in env and use that here instead of Env::Default
+ : env_(Env::Default()),
+ file_prefix_(file_prefix),
+ num_outstanding_events_(0) {}
+
+bool EventsWriter::Init() {
+ if (recordio_writer_.get() != nullptr) {
+ CHECK(!filename_.empty());
+ if (FileHasDisappeared()) {
+ // Warn user of data loss and let .reset() below do basic cleanup.
+ if (num_outstanding_events_ > 0) {
+ LOG(WARNING) << "Re-intialization, attempting to open a new file, "
+ << num_outstanding_events_ << " events will be lost.";
+ }
+ } else {
+ // No-op: File is present and writer is initialized.
+ return true;
+ }
+ }
+
+ int64 time_in_seconds = env_->NowMicros() / 1000000;
+
+ filename_ = strings::Printf(
+ "%s.out.tfevents.%010lld.%s", file_prefix_.c_str(),
+ static_cast<long long>(time_in_seconds), port::Hostname().c_str());
+ port::AdjustFilenameForLogging(&filename_);
+
+ WritableFile* file;
+ Status s = env_->NewWritableFile(filename_, &file);
+ if (!s.ok()) {
+ LOG(ERROR) << "Could not open events file: " << filename_ << ": " << s;
+ return false;
+ }
+ recordio_file_.reset(file);
+ recordio_writer_.reset(new io::RecordWriter(recordio_file_.get()));
+ if (recordio_writer_.get() == NULL) {
+ LOG(ERROR) << "Could not create record writer";
+ return false;
+ }
+ num_outstanding_events_ = 0;
+ VLOG(1) << "Successfully opened events file: " << filename_;
+ {
+ // Write the first event with the current version, and flush
+ // right away so the file contents will be easily determined.
+
+ Event event;
+ event.set_wall_time(time_in_seconds);
+ event.set_file_version(strings::StrCat(kVersionPrefix, kCurrentVersion));
+ WriteEvent(event);
+ Flush();
+ }
+ return true;
+}
+
+string EventsWriter::FileName() {
+ if (filename_.empty()) {
+ Init();
+ }
+ return filename_;
+}
+
+void EventsWriter::WriteSerializedEvent(const string& event_str) {
+ if (recordio_writer_.get() == NULL) {
+ if (!Init()) {
+ LOG(ERROR) << "Write failed because file could not be opened.";
+ return;
+ }
+ }
+ num_outstanding_events_++;
+ recordio_writer_->WriteRecord(event_str);
+}
+
+void EventsWriter::WriteEvent(const Event& event) {
+ string record;
+ event.AppendToString(&record);
+ WriteSerializedEvent(record);
+}
+
+bool EventsWriter::Flush() {
+ if (num_outstanding_events_ == 0) return true;
+ CHECK(recordio_file_.get() != NULL) << "Unexpected NULL file";
+ // The FileHasDisappeared() condition is necessary because
+ // recordio_writer_->Sync() can return true even if the underlying
+ // file has been deleted. EventWriter.FileDeletionBeforeWriting
+ // demonstrates this and will fail if the FileHasDisappeared()
+ // condition is removed.
+ // Also, we deliberately attempt to Sync() before checking for a
+ // disappearing file, in case for some file system File::Exists() is
+ // false after File::Open() but before File::Sync().
+ if (!recordio_file_->Flush().ok() || !recordio_file_->Sync().ok() ||
+ FileHasDisappeared()) {
+ LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to "
+ << filename_;
+ return false;
+ }
+ VLOG(1) << "Wrote " << num_outstanding_events_ << " events to disk.";
+ num_outstanding_events_ = 0;
+ return true;
+}
+
+bool EventsWriter::Close() {
+ bool return_value = Flush();
+ if (recordio_file_.get() != NULL) {
+ Status s = recordio_file_->Close();
+ if (!s.ok()) {
+ LOG(ERROR) << "Error when closing previous event file: " << filename_
+ << ": " << s;
+ return_value = false;
+ }
+ recordio_writer_.reset(NULL);
+ recordio_file_.reset(NULL);
+ }
+ num_outstanding_events_ = 0;
+ return return_value;
+}
+
+bool EventsWriter::FileHasDisappeared() {
+ if (env_->FileExists(filename_)) {
+ return false;
+ } else {
+ // This can happen even with non-null recordio_writer_ if some other
+ // process has removed the file.
+ LOG(ERROR) << "The events file " << filename_ << " has disappeared.";
+ return true;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h
new file mode 100644
index 0000000000..e6b94ad265
--- /dev/null
+++ b/tensorflow/core/util/events_writer.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_UTIL_EVENTS_WRITER_H_
+#define TENSORFLOW_UTIL_EVENTS_WRITER_H_
+
+#include <memory>
+#include <string>
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+
+class EventsWriter {
+ public:
+#ifndef SWIG
+ // Prefix of version string present in the first entry of every event file.
+ static constexpr const char* kVersionPrefix = "brain.Event:";
+ static constexpr const int kCurrentVersion = 1;
+#endif
+
+ // Events files typically have a name of the form
+ // '/some/file/path/my.file.out.tfevents.[timestamp].[hostname]'
+ // To create an EventsWriter, the user should provide file_prefix =
+ // '/some/file/path/my.file'
+ // The EventsWriter will append '.out.tfevents.[timestamp].[hostname]'
+ // to the ultimate filename once Init() is called.
+ // Note that it is not recommended to simultaneously have two
+ // EventsWriters writing to the same file_prefix.
+ explicit EventsWriter(const string& file_prefix);
+ ~EventsWriter() { Close(); } // Autoclose in destructor.
+
+ // Sets the event file filename and opens file for writing. If not called by
+ // user, will be invoked automatically by a call to FileName() or Write*().
+ // Returns false if the file could not be opened. Idempotent: if file exists
+ // and is open this is a no-op. If on the other hand the file was opened,
+ // but has since disappeared (e.g. deleted by another process), this will open
+ // a new file with a new timestamp in its filename.
+ bool Init();
+
+ // Returns the filename for the current events file:
+ // filename_ = [file_prefix_].out.tfevents.[timestamp].[hostname]
+ string FileName();
+
+ // Append "event" to the file. The "tensorflow::" part is for swig happiness.
+ void WriteEvent(const tensorflow::Event& event);
+
+ // Append "event_str", a serialized Event, to the file.
+ // Note that this function does NOT check that de-serializing event_str
+ // results in a valid Event proto.
+ void WriteSerializedEvent(const string& event_str);
+
+ // EventsWriter automatically flushes and closes on destruction, but
+ // these two methods are provided for users who want to write to disk sooner
+ // and/or check for success.
+ // Flush() pushes outstanding events to disk. Returns false if the
+ // events file could not be created, or if the file exists but could not
+ // be written to.
+ // Close() calls Flush() and then closes the current events file.
+ // Returns true only if both the flush and the closure were successful.
+ bool Flush();
+ bool Close();
+
+ private:
+ bool FileHasDisappeared(); // True if filename_ does not exist.
+
+ Env* env_;
+ const string file_prefix_;
+ string filename_;
+ std::unique_ptr<WritableFile> recordio_file_;
+ std::unique_ptr<io::RecordWriter> recordio_writer_;
+ int num_outstanding_events_;
+ TF_DISALLOW_COPY_AND_ASSIGN(EventsWriter);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_EVENTS_WRITER_H_
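A short usage sketch of the interface above (the prefix path is purely illustrative); it mirrors the WriteFlush test in events_writer_test.cc below.

    #include "tensorflow/core/util/event.pb.h"
    #include "tensorflow/core/util/events_writer.h"

    void WriteOneEvent() {
      // Events land next to the prefix as
      // <prefix>.out.tfevents.<timestamp>.<hostname>.
      tensorflow::EventsWriter writer("/tmp/demo/events");  // illustrative prefix
      tensorflow::Event event;
      event.set_wall_time(1234);
      event.set_step(34);
      tensorflow::Summary::Value* v = event.mutable_summary()->add_value();
      v->set_tag("loss");
      v->set_simple_value(3.14159f);
      writer.WriteEvent(event);  // Lazily calls Init() and opens the file.
      writer.Flush();            // Push the buffered record to disk now.
    }  // ~EventsWriter() calls Close(), which flushes and closes the file.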
diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc
new file mode 100644
index 0000000000..f6523ead92
--- /dev/null
+++ b/tensorflow/core/util/events_writer_test.cc
@@ -0,0 +1,198 @@
+#include "tensorflow/core/util/events_writer.h"
+
+#include <math.h>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+namespace {
+
+// shorthand
+Env* env() { return Env::Default(); }
+
+void WriteSimpleValue(EventsWriter* writer, double wall_time, int64 step,
+ const string& tag, float simple_value) {
+ Event event;
+ event.set_wall_time(wall_time);
+ event.set_step(step);
+ Summary::Value* summ_val = event.mutable_summary()->add_value();
+ summ_val->set_tag(tag);
+ summ_val->set_simple_value(simple_value);
+ writer->WriteEvent(event);
+}
+
+void WriteFile(EventsWriter* writer) {
+ WriteSimpleValue(writer, 1234, 34, "foo", 3.14159);
+ WriteSimpleValue(writer, 2345, 35, "bar", -42);
+}
+
+static bool ReadEventProto(io::RecordReader* reader, uint64* offset,
+ Event* proto) {
+ string record;
+ Status s = reader->ReadRecord(offset, &record);
+ if (!s.ok()) {
+ return false;
+ }
+ return ParseProtoUnlimited(proto, record);
+}
+
+void VerifyFile(const string& filename) {
+ CHECK(env()->FileExists(filename));
+ RandomAccessFile* event_file;
+ TF_CHECK_OK(env()->NewRandomAccessFile(filename, &event_file));
+ io::RecordReader* reader = new io::RecordReader(event_file);
+
+ uint64 offset = 0;
+
+ Event actual;
+ CHECK(ReadEventProto(reader, &offset, &actual));
+ VLOG(1) << actual.ShortDebugString();
+ // Wall time should be within 5s of now.
+
+ double current_time = env()->NowMicros() / 1000000.0;
+ EXPECT_LT(fabs(actual.wall_time() - current_time), 5);
+ // Should have the current version number.
+ EXPECT_EQ(actual.file_version(),
+ strings::StrCat(EventsWriter::kVersionPrefix,
+ EventsWriter::kCurrentVersion));
+
+ Event expected;
+ CHECK(ReadEventProto(reader, &offset, &actual));
+ VLOG(1) << actual.ShortDebugString();
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "wall_time: 1234 step: 34 "
+ "summary { value { tag: 'foo' simple_value: 3.14159 } }",
+ &expected));
+ // TODO(keveman): Enable this check
+ // EXPECT_THAT(expected, EqualsProto(actual));
+
+ CHECK(ReadEventProto(reader, &offset, &actual));
+ VLOG(1) << actual.ShortDebugString();
+ ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+ "wall_time: 2345 step: 35 "
+ "summary { value { tag: 'bar' simple_value: -42 } }",
+ &expected));
+ // TODO(keveman): Enable this check
+ // EXPECT_THAT(expected, EqualsProto(actual));
+
+ TF_CHECK_OK(env()->DeleteFile(filename));
+
+ delete reader;
+ delete event_file;
+}
+
+string GetDirName(const string& suffix) {
+ return io::JoinPath(testing::TmpDir(), suffix);
+}
+
+TEST(EventWriter, WriteFlush) {
+ string file_prefix = GetDirName("/writeflush_test");
+ EventsWriter writer(file_prefix);
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Flush());
+ string filename = writer.FileName();
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, WriteClose) {
+ string file_prefix = GetDirName("/writeclose_test");
+ EventsWriter writer(file_prefix);
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Close());
+ string filename = writer.FileName();
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, WriteDelete) {
+ string file_prefix = GetDirName("/writedelete_test");
+ EventsWriter* writer = new EventsWriter(file_prefix);
+ WriteFile(writer);
+ string filename = writer->FileName();
+ delete writer;
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, FailFlush) {
+ string file_prefix = GetDirName("/failflush_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ WriteFile(&writer);
+ EXPECT_TRUE(env()->FileExists(filename));
+ env()->DeleteFile(filename);
+ EXPECT_FALSE(env()->FileExists(filename));
+ EXPECT_FALSE(writer.Flush());
+ EXPECT_FALSE(env()->FileExists(filename));
+}
+
+TEST(EventWriter, FailClose) {
+ string file_prefix = GetDirName("/failclose_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ WriteFile(&writer);
+ EXPECT_TRUE(env()->FileExists(filename));
+ env()->DeleteFile(filename);
+ EXPECT_FALSE(env()->FileExists(filename));
+ EXPECT_FALSE(writer.Close());
+ EXPECT_FALSE(env()->FileExists(filename));
+}
+
+TEST(EventWriter, InitWriteClose) {
+ string file_prefix = GetDirName("/initwriteclose_test");
+ EventsWriter writer(file_prefix);
+ EXPECT_TRUE(writer.Init());
+ string filename0 = writer.FileName();
+ EXPECT_TRUE(env()->FileExists(filename0));
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Close());
+ string filename1 = writer.FileName();
+ EXPECT_EQ(filename0, filename1);
+ VerifyFile(filename1);
+}
+
+TEST(EventWriter, NameWriteClose) {
+ string file_prefix = GetDirName("/namewriteclose_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ EXPECT_TRUE(env()->FileExists(filename));
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Close());
+ VerifyFile(filename);
+}
+
+TEST(EventWriter, NameClose) {
+ string file_prefix = GetDirName("/nameclose_test");
+ EventsWriter writer(file_prefix);
+ string filename = writer.FileName();
+ EXPECT_TRUE(writer.Close());
+ EXPECT_TRUE(env()->FileExists(filename));
+ env()->DeleteFile(filename);
+}
+
+TEST(EventWriter, FileDeletionBeforeWriting) {
+ string file_prefix = GetDirName("/fdbw_test");
+ EventsWriter writer(file_prefix);
+ string filename0 = writer.FileName();
+ EXPECT_TRUE(env()->FileExists(filename0));
+ env()->SleepForMicroseconds(
+ 2000000); // To make sure timestamp part of filename will differ.
+ env()->DeleteFile(filename0);
+ EXPECT_TRUE(writer.Init()); // Init should reopen file.
+ WriteFile(&writer);
+ EXPECT_TRUE(writer.Flush());
+ string filename1 = writer.FileName();
+ EXPECT_NE(filename0, filename1);
+ VerifyFile(filename1);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc
new file mode 100644
index 0000000000..4cf58b8979
--- /dev/null
+++ b/tensorflow/core/util/guarded_philox_random.cc
@@ -0,0 +1,39 @@
+#include "tensorflow/core/util/guarded_philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) {
+ // Grab seed Attrs.
+ int64 seed, seed2;
+ auto status = context->GetAttr("seed", &seed);
+ if (!status.ok()) return status;
+ status = context->GetAttr("seed2", &seed2);
+ if (!status.ok()) return status;
+
+ // Initialize with the given seeds
+ Init(seed, seed2);
+ return Status::OK();
+}
+
+void GuardedPhiloxRandom::Init(int64 seed, int64 seed2) {
+ CHECK(!initialized_);
+ if (seed == 0 && seed2 == 0) {
+ // If both seeds are unspecified, use completely random seeds.
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+ mutex_lock lock(mu_);
+ generator_ = random::PhiloxRandom(seed, seed2);
+ initialized_ = true;
+}
+
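+// Returns a copy of the generator for the caller's private use and advances
+// the shared generator past the reserved samples, so the copy can be consumed
+// without holding mu_.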
+random::PhiloxRandom GuardedPhiloxRandom::ReserveSamples128(int64 samples) {
+ CHECK(initialized_);
+ mutex_lock lock(mu_);
+ auto local = generator_;
+ generator_.Skip(samples);
+ return local;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h
new file mode 100644
index 0000000000..6e9cb9f99c
--- /dev/null
+++ b/tensorflow/core/util/guarded_philox_random.h
@@ -0,0 +1,56 @@
+#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// A thread-safe wrapper around a Philox generator. Example usage:
+//
+//   GuardedPhiloxRandom generator;
+//   generator.Init(context);
+//
+//   // In thread-safe code
+//   const int samples = ...;
+//   auto local_generator = generator.ReserveSamples128(samples);
+//   for (int i = 0; i < samples; i++) {
+//     Array<uint32, 4> sample = local_generator();
+//     // Use sample
+//   }
+//
+class GuardedPhiloxRandom {
+ public:
+ // Must call Init to finish initialization
+ GuardedPhiloxRandom() : initialized_(false) {}
+
+ // Initialize the generator from attributes "seed" and "seed2".
+ // If both seeds are unspecified, use random seeds.
+ // Must be called exactly once.
+ Status Init(OpKernelConstruction* context);
+
+ // Initialize with given seeds.
+ void Init(int64 seed, int64 seed2);
+
+ // Reserve a certain number of 128-bit samples.
+ // This function is thread safe. The returned generator is valid for the
+ // given number of samples, and can be used without a lock.
+ random::PhiloxRandom ReserveSamples128(int64 samples);
+
+ // Reserve a certain number of 32-bit samples
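+  // Each 128-bit sample yields four 32-bit values, so this rounds up: e.g.
+  // reserving ten 32-bit samples reserves (10 + 3) / 4 = 3 128-bit samples.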
+ random::PhiloxRandom ReserveSamples32(int64 samples) {
+ return ReserveSamples128((samples + 3) / 4);
+ }
+
+ private:
+ mutex mu_;
+ random::PhiloxRandom generator_ GUARDED_BY(mu_);
+ bool initialized_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GuardedPhiloxRandom);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc
new file mode 100644
index 0000000000..24273e5ca4
--- /dev/null
+++ b/tensorflow/core/util/padding.cc
@@ -0,0 +1,24 @@
+#include "tensorflow/core/util/padding.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+ Padding* value) {
+ string str_value;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value));
+ if (str_value == "SAME") {
+ *value = SAME;
+ } else if (str_value == "VALID") {
+ *value = VALID;
+ } else {
+ return errors::NotFound(str_value, " is not an allowed padding type");
+ }
+ return Status::OK();
+}
+
+string GetPaddingAttrString() { return "padding: {'SAME', 'VALID'}"; }
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h
new file mode 100644
index 0000000000..66cd96abdb
--- /dev/null
+++ b/tensorflow/core/util/padding.h
@@ -0,0 +1,37 @@
+#ifndef TENSORFLOW_UTIL_PADDING_H_
+#define TENSORFLOW_UTIL_PADDING_H_
+
+// This file contains helper routines to deal with padding in various ops and
+// kernels.
+
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// Padding: the padding we apply to the input tensor along the rows and columns
+// dimensions. This is usually used to make sure that the spatial dimensions do
+// not shrink when we progress with convolutions. Two types of padding are
+// supported:
+// VALID: No padding is carried out.
+// SAME: The pad value is computed so that the output will have the same
+// dimensions as the input.
+// The padded area is zero-filled.
+enum Padding {
+ VALID = 1, // No padding.
+ SAME = 2, // Input and output layers have the same size.
+};
+
+// Return the string containing the list of valid padding types, that can be
+// used as an Attr() in REGISTER_OP.
+string GetPaddingAttrString();
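+
+// Illustrative sketch (not part of this header): an op registration might use
+// GetPaddingAttrString() for its attr spec, and a kernel might parse the attr
+// with the GetNodeAttr() specialization declared below. "MyConvOp" is a
+// hypothetical name used only for illustration.
+//
+//   REGISTER_OP("MyConvOp")
+//       .Attr(GetPaddingAttrString());   // "padding: {'SAME', 'VALID'}"
+//
+//   Padding padding;
+//   TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "padding", &padding));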
+
+// Specialization to parse an attribute directly into a Padding enum.
+Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+ Padding* value);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_PADDING_H_
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
new file mode 100644
index 0000000000..12eb076a4d
--- /dev/null
+++ b/tensorflow/core/util/port.cc
@@ -0,0 +1,13 @@
+#include "tensorflow/core/util/port.h"
+
+namespace tensorflow {
+
+bool IsGoogleCudaEnabled() {
+#if GOOGLE_CUDA
+ return true;
+#else
+ return false;
+#endif
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
new file mode 100644
index 0000000000..8b9d033d63
--- /dev/null
+++ b/tensorflow/core/util/port.h
@@ -0,0 +1,11 @@
+#ifndef TENSORFLOW_UTIL_PORT_H_
+#define TENSORFLOW_UTIL_PORT_H_
+
+namespace tensorflow {
+
+// Returns true if GOOGLE_CUDA is defined.
+bool IsGoogleCudaEnabled();
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_PORT_H_
diff --git a/tensorflow/core/util/saved_tensor_slice.proto b/tensorflow/core/util/saved_tensor_slice.proto
new file mode 100644
index 0000000000..f6599d9669
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice.proto
@@ -0,0 +1,76 @@
+// Protocol buffers for saved tensor slices. It's used for the brain tensor
+// ops checkpoints and the V3 checkpoints in dist_belief.
+
+// A checkpoint file is an sstable. The value for each record is a serialized
+// SavedTensorSlices message (defined below).
+//
+// Each checkpoint file has a record with the empty key (""), which corresponds
+// to a SavedTensorSlices message that contains a "meta" field serving as a
+// table of contents of all the tensor slices saved in this file. Since the key
+// is "", it's always the first record in each file.
+//
+// Each of the rest of the records in a checkpoint stores the raw data of a
+// particular tensor slice, in SavedSlice format. The corresponding key is an
+// ordered code that encodes the name of the tensor and the slice
+// information. The name is also stored in the SavedSlice message for ease of
+// debugging and manual examination.
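+//
+// Conceptually the layout is (a sketch, not literal file contents;
+// "tensor_a" and the slice labels are hypothetical):
+//
+//   key ""                                -> SavedTensorSlices { meta ... }
+//   key ordered_code("tensor_a", slice_0) -> SavedTensorSlices { data ... }
+//   key ordered_code("tensor_a", slice_1) -> SavedTensorSlices { data ... }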
+
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/tensor_slice.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Metadata describing the set of slices of the same tensor saved in a
+// checkpoint file.
+message SavedSliceMeta {
+ // Name of the tensor.
+ string name = 1;
+
+ // Shape of the tensor
+ TensorShapeProto shape = 2;
+
+ // Type of the tensor
+ DataType type = 3;
+
+ // Explicit list of slices saved in the checkpoint file.
+ repeated TensorSliceProto slice = 4;
+};
+
+// Metadata describing the set of tensor slices saved in a checkpoint file.
+// It is always stored at the beginning of each checkpoint file.
+message SavedTensorSliceMeta {
+ // Each SavedSliceMeta describes the slices for one tensor.
+ repeated SavedSliceMeta tensor = 1;
+};
+
+// Saved tensor slice: it stores the name of the tensors, the slice, and the
+// raw data.
+message SavedSlice {
+ // Name of the tensor that this slice belongs to. This must be identical to
+ // the name used to encode the key for this record.
+ string name = 1;
+
+ // Extent of the slice. Must have one entry for each of the dimension of the
+ // tensor that this slice belongs to.
+ TensorSliceProto slice = 2;
+
+ // The raw data of the slice is stored as a TensorProto. Only raw data are
+ // stored (we don't fill in fields such as dtype or tensor_shape).
+ TensorProto data = 3;
+};
+
+// Each record in a v3 checkpoint file is a serialized SavedTensorSlices
+// message.
+message SavedTensorSlices {
+ // This is only present at the first item of each checkpoint file and serves
+ // as a table of contents, listing all the tensor slices saved in this file.
+ SavedTensorSliceMeta meta = 1;
+
+ // This exists in all but the first item of each checkpoint file.
+ SavedSlice data = 2;
+};
diff --git a/tensorflow/core/util/saved_tensor_slice_util.cc b/tensorflow/core/util/saved_tensor_slice_util.cc
new file mode 100644
index 0000000000..7a5903f07f
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util.cc
@@ -0,0 +1,76 @@
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/ordered_code.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+const char kSavedTensorSlicesKey[] = "";
+
+string EncodeTensorNameSlice(const string& name, const TensorSlice& slice) {
+ string buffer;
+ // All the tensor slice keys will start with a 0
+ tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, 0);
+ tensorflow::strings::OrderedCode::WriteString(&buffer, name);
+ tensorflow::strings::OrderedCode::WriteNumIncreasing(&buffer, slice.dims());
+ for (int d = 0; d < slice.dims(); ++d) {
+    // A trivial extent (meaning we take EVERYTHING) will default to -1 for
+    // both start and length. These will be properly parsed.
+ tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
+ slice.start(d));
+ tensorflow::strings::OrderedCode::WriteSignedNumIncreasing(&buffer,
+ slice.length(d));
+ }
+ return buffer;
+}
+
+Status DecodeTensorNameSlice(const string& code, string* name,
+ tensorflow::TensorSlice* slice) {
+ StringPiece src(code);
+ uint64 x;
+ if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
+ return errors::Internal("Failed to parse the leading number: src = ", src);
+ }
+ if (x != 0) {
+ return errors::Internal(
+ "The leading number should always be 0 for any valid key: src = ", src);
+ }
+ if (!tensorflow::strings::OrderedCode::ReadString(&src, name)) {
+ return errors::Internal("Failed to parse the tensor name: src = ", src);
+ }
+ if (!tensorflow::strings::OrderedCode::ReadNumIncreasing(&src, &x)) {
+ return errors::Internal("Failed to parse the tensor rank: src = ", src);
+ }
+ if (x == 0) {
+ return errors::Internal("Expecting positive rank of the tensor, got ", x,
+ ", src = ", src);
+ }
+ if (x >= kint32max) {
+ return errors::Internal("Too many elements ", x);
+ }
+ slice->SetFullSlice(x);
+ for (int d = 0; d < static_cast<int32>(x); ++d) {
+    // We expect two integers (start, length) for each dimension.
+ int64 start, length;
+ if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
+ &start)) {
+ return errors::Internal("Failed to parse start: src = ", src);
+ }
+ if (!tensorflow::strings::OrderedCode::ReadSignedNumIncreasing(&src,
+ &length)) {
+ return errors::Internal("Failed to parse length: src = ", src);
+ }
+ if (length >= 0) {
+ // a non-trivial extent
+ slice->set_start(d, start);
+ slice->set_length(d, length);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
new file mode 100644
index 0000000000..6206cd8538
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -0,0 +1,110 @@
+// Utilities for saving/restoring tensor slice checkpoints.
+
+#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+
+#include <string> // for string
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/status.h" // for Status
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+// The key for the metadata in the tensor slice checkpoint files. It is "" so
+// that the metadata is always at the beginning of a checkpoint file.
+extern const char kSavedTensorSlicesKey[];
+
+// Encode a tensor name + a tensor slice into an ordered code and outputs it as
+// a string.
+// The format is
+// <0>
+// <tensor_name>
+// <rank>
+// <dim-0-start><dim-0-length>
+// <dim-1-start><dim-1-length>
+// ...
+
+string EncodeTensorNameSlice(const string& name,
+ const tensorflow::TensorSlice& slice);
+
+// Parse out the name and the slice from string encoded as an ordered code.
+Status DecodeTensorNameSlice(const string& code, string* name,
+ tensorflow::TensorSlice* slice);
+
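+// A minimal round-trip sketch (assuming the slice string format accepted by
+// TensorSlice::ParseOrDie, e.g. "-" for a full dimension and "start,length"
+// otherwise; "my_tensor" is just an example name):
+//
+//   TensorSlice slice = TensorSlice::ParseOrDie("-:0,10");
+//   string key = EncodeTensorNameSlice("my_tensor", slice);
+//   string name;
+//   TensorSlice parsed;
+//   TF_CHECK_OK(DecodeTensorNameSlice(key, &name, &parsed));
+//   // name == "my_tensor" and parsed.DebugString() == "-:0,10".
+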
+template <typename T>
+struct SaveTypeTraits;
+
+template <typename T>
+const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
+ const TensorProto& t);
+
+template <typename T>
+protobuf::RepeatedField<typename SaveTypeTraits<T>::SavedType>*
+MutableTensorProtoData(TensorProto* t);
+
+template <typename T>
+void Fill(T* data, size_t n, TensorProto* t);
+
+#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \
+ template <> \
+ struct SaveTypeTraits<TYPE> { \
+ static constexpr bool supported = true; \
+ typedef FTYPE SavedType; \
+ }; \
+ template <> \
+ inline const FTYPE* TensorProtoData<TYPE>(const TensorProto& t) { \
+ static_assert(SaveTypeTraits<TYPE>::supported, \
+ "Specified type " #TYPE " not supported for Restore"); \
+ return reinterpret_cast<const FTYPE*>(t.FIELD##_val().data()); \
+ } \
+ template <> \
+ inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \
+ TensorProto * t) { \
+ static_assert(SaveTypeTraits<TYPE>::supported, \
+ "Specified type " #TYPE " not supported for Save"); \
+ return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \
+ t->mutable_##FIELD##_val()); \
+ } \
+ template <> \
+ inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
+ typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \
+ t->mutable_##FIELD##_val()->Swap(&copy); \
+ }
+
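+// For example, TENSOR_PROTO_EXTRACT_TYPE(float, float, float) below generates
+// a TensorProtoData<float>() that returns t.float_val().data(), a
+// MutableTensorProtoData<float>() that returns t->mutable_float_val() viewed
+// as a protobuf::RepeatedField<float>, and a Fill() overload that swaps a
+// copy of the input into float_val.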
+TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
+TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
+TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64);
+TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
+
+#undef TENSOR_PROTO_EXTRACT_TYPE
+
+template <>
+struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
+
+template <>
+inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
+ static_assert(SaveTypeTraits<qint32>::supported,
+ "Specified type qint32 not supported for Restore");
+ return reinterpret_cast<const int32*>(t.int_val().data());
+}
+
+inline void Fill(const qint32* data, size_t n, TensorProto* t) {
+ const int32* p = reinterpret_cast<const int32*>(data);
+ typename protobuf::RepeatedField<int32> copy(p, p + n);
+ t->mutable_int_val()->Swap(&copy);
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
diff --git a/tensorflow/core/util/saved_tensor_slice_util_test.cc b/tensorflow/core/util/saved_tensor_slice_util_test.cc
new file mode 100644
index 0000000000..2c34c903db
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util_test.cc
@@ -0,0 +1,32 @@
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// Testing serialization of tensor name and tensor slice in the ordered code
+// format.
+TEST(TensorShapeUtilTest, TensorNameSliceToOrderedCode) {
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5");
+ string buffer = EncodeTensorNameSlice("foo", s);
+ string name;
+ s.Clear();
+ TF_CHECK_OK(DecodeTensorNameSlice(buffer, &name, &s));
+ EXPECT_EQ("foo", name);
+ EXPECT_EQ("-:-:1,3:4,5", s.DebugString());
+ }
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/sparse/README.md b/tensorflow/core/util/sparse/README.md
new file mode 100644
index 0000000000..7b0799eb0e
--- /dev/null
+++ b/tensorflow/core/util/sparse/README.md
@@ -0,0 +1,222 @@
+SparseTensor
+============
+
+Sparse Tensors are stored as two dense tensors and a shape:
+
+* `indices`: a `brain::Tensor` storing a matrix, typically `int64`
+* `values`: a `brain::Tensor` storing a vector with values of type T.
+* `shape`: a `TensorShape` storing the bounds of the underlying tensor
+* `order`: (optional) a `gtl::InlinedVector<int64,8>` with the dimensions
+ along which the indices are ordered.
+
+Let
+
+ ix = indices.matrix<int64>()
+ vals = values.vec<T>()
+
+The shape of `ix` is `N x NDIMS`, and each row corresponds to the
+index of a single element of the sparse tensor.
+
+The length of `vals` must be `N`, and `vals(i)` corresponds to the
+value with index `ix(i,:)`.
+
+Shape must be a `TensorShape` with `dims() == NDIMS`.
+The shape is the full shape of the dense tensor these indices
+represent.
+
+To be specific, the representation (pseudocode) is:
+
+ tensor[ix[i,:]] == vals[i] for i = 0, ..., N-1
+
+Ordering
+--------
+
+Indices need not be provided in order. For example, the following
+index matrix is ordered according to dimension order `{0, 1, 2}`.
+
+ [0 0 1]
+ [0 1 1]
+ [2 0 2]
+
+However, you can provide an unordered version:
+
+ [2 0 2]
+ [0 0 1]
+ [0 1 1]
+
+If the SparseTensor is constructed without a provided order, then
+the default order is `{-1, ..., -1}`. Certain operations will fail or crash
+when the order is not provided.
+
+Resorting the SparseTensor in-place (which resorts the underlying index and
+values tensors in-place) will update the order. Reordering the matrix costs
+`O(N*log(N))` time and requires `O(N)` additional temporary space to store a
+reordering index. If no order is provided at construction and the tensor is
+not reordered, the following will happen:
+
+* `group()` will **raise an assertion failure**
+* `IndicesValid()` will **raise an assertion failure**
+
+To update the internal index ordering after construction, call
+`Reorder<T>()` via, e.g., `Reorder<T>({0,1,2})`.
+After this step, all the above methods should work correctly.
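+
+As a minimal sketch (assuming a string-valued SparseTensor `sp` that was
+constructed without an explicit order):
+
+    sp.Reorder<string>({0, 1, 2});  // Sort indices by dim 0, then 1, then 2.
+    CHECK(sp.IndicesValid());       // group() and IndicesValid() now succeed.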
+
+The method `IndicesValid()` checks to make sure:
+
+* `0 <= ix(i, d) < shape.dim_size(d)`
+* indices do not repeat
+* indices are in order
+
+Iterating
+---------
+
+### group({grouping dims})
+
+* provides an iterator that groups entries according to
+ dimensions you care about
+* may require a sort if your data isn't presorted in a way that's
+ compatible with grouping_dims
+* for each group, returns the group index (values of the group
+ dims for this iteration), the subset of indices in this group,
+  and the subset of values in this group. These are lazy outputs,
+  so to read them individually, copy them as in the example
+  below.
+
+#### **NOTE**
+`group({dim0, ..., dimk})` will **raise an assertion failure** if the
+order of the SparseTensor does not match the dimensions you wish to group by.
+You must either have your indices in the correct order and construct the
+SparseTensor with
+
+ order = {dim0, ..., dimk, ...}
+
+or call
+
+ Reorder<T>({dim0, .., dimk, ...})
+
+to sort the SparseTensor before grouping.
+
+Example of grouping:
+
+    Tensor indices(DT_INT64, TensorShape({N, NDIMS}));
+    Tensor vals(DT_STRING, TensorShape({N}));
+    TensorShape shape({dim0,...});
+    SparseTensor sp(indices, vals, shape);
+ sp.Reorder<string>({1, 2, 0, 3, ...}); // Must provide NDIMS dims.
+ // group according to dims 1 and 2
+ for (const auto& g : sp.group({1, 2})) {
+ cout << "vals of ix[:, 1,2] for this group: "
+ << g.group()[0] << ", " << g.group()[1];
+ cout << "full indices of group:\n" << g.indices();
+ cout << "values of group:\n" << g.values();
+
+ TTypes<int64>::UnalignedMatrix g_ix = g.indices();
+ TTypes<string>::UnalignedVec g_v = g.values();
+ ASSERT(g_ix.dimension(0) == g_v.size()); // number of elements match.
+ }
+
+
+ToDense
+--------
+
+Converts sparse tensor to dense. You must provide a pointer to the
+dense tensor (preallocated). `ToDense()` will optionally
+preinitialize the tensor with zeros.
+
+Shape checking is performed, as is boundary checking.
+
+    Tensor indices(DT_INT64, TensorShape({N, NDIMS}));
+    Tensor vals(DT_STRING, TensorShape({N}));
+    TensorShape shape({dim0,...});
+    SparseTensor sp(indices, vals, shape);
+ ASSERT(sp.IndicesValid()); // checks ordering & index bounds.
+
+ Tensor dense(DT_STRING, shape);
+ // initialize other indices to zero. copy.
+ ASSERT(sp.ToDense<string>(&dense, true));
+
+
+Concat
+--------
+
+Concatenates multiple SparseTensors and returns a new SparseTensor.
+This concatenation is with respect to the "dense" versions of these
+SparseTensors. Concatenation is performed along dimension order[0]
+of all tensors. As a result, shape[order[0]] may differ across
+the inputs, but shape[d] for d != order[0] must match across all inputs.
+
+We call order[0] the **primary dimension**.
+
+**Prerequisites**
+
+* The inputs' ranks must all match.
+* The inputs' order[0] must all match.
+* The inputs' shapes must all match except for dimension order[0].
+* The inputs' values must all be of the same type.
+
+If any of these are false, concat will die with an assertion failure.
+
+Example:
+Concatenate two sparse matrices along columns.
+
+Matrix 1:
+
+ [0 0 1]
+ [2 0 0]
+ [3 0 4]
+
+Matrix 2:
+
+ [0 0 0 0 0]
+ [0 1 0 0 0]
+ [2 0 0 1 0]
+
+Concatenated Matrix:
+
+ [0 0 1 0 0 0 0 0]
+ [2 0 0 0 1 0 0 0]
+ [3 0 4 2 0 0 1 0]
+
+Expected input shapes, orders, and `nnz()`:
+
+ shape_1 = TensorShape({3, 3})
+    shape_2 = TensorShape({3, 5})
+ order_1 = {1, 0} // primary order is 1, columns
+ order_2 = {1, 0} // primary order is 1, must match
+ nnz_1 = 4
+ nnz_2 = 3
+
+Output shapes and orders:
+
+    conc_shape = TensorShape({3, 8})   // primary dim increased, others same
+ conc_order = {1, 0} // Orders match along all inputs
+ conc_nnz = 7 // Sum of nonzeros of inputs
+
+Coding Example:
+
+    Tensor ix1(DT_INT64, TensorShape({N1, 3}));
+    Tensor vals1(DT_STRING, TensorShape({N1}));
+    Tensor ix2(DT_INT64, TensorShape({N2, 3}));
+    Tensor vals2(DT_STRING, TensorShape({N2}));
+    Tensor ix3(DT_INT64, TensorShape({N3, 3}));
+    Tensor vals3(DT_STRING, TensorShape({N3}));
+
+ SparseTensor st1(ix1, vals1, TensorShape({10, 20, 5}), {1, 0, 2});
+ SparseTensor st2(ix2, vals2, TensorShape({10, 10, 5}), {1, 0, 2});
+ // For kicks, st3 indices are out of order, but order[0] matches so we
+ // can still concatenate along this dimension.
+ SparseTensor st3(ix3, vals3, TensorShape({10, 30, 5}), {1, 2, 0});
+
+ SparseTensor conc = SparseTensor::Concat<string>({st1, st2, st3});
+ Tensor ix_conc = conc.indices();
+ Tensor vals_conc = conc.values();
+    EXPECT_EQ(conc.num_entries(),
+              st1.num_entries() + st2.num_entries() + st3.num_entries());
+    EXPECT_EQ(conc.shape(), TensorShape({10, 60, 5}));
+    EXPECT_EQ(conc.order(), {-1, -1, -1});
+
+ // Reorder st3 so all input tensors have the exact same orders.
+ st3.Reorder<string>({1, 0, 2});
+ SparseTensor conc2 = SparseTensor::Concat<string>({st1, st2, st3});
+    EXPECT_EQ(conc2.order(), {1, 0, 2});
+ // All indices' orders matched, so output is in order.
+ EXPECT_TRUE(conc2.IndicesValid());
diff --git a/tensorflow/core/util/sparse/dim_comparator.h b/tensorflow/core/util/sparse/dim_comparator.h
new file mode 100644
index 0000000000..57473867cf
--- /dev/null
+++ b/tensorflow/core/util/sparse/dim_comparator.h
@@ -0,0 +1,60 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+#define TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+/////////////////
+// DimComparator
+/////////////////
+//
+// Helper class, mainly used by the IndexSortOrder. This comparator
+// can be passed to e.g. std::sort, or any other sorter, to sort two
+// rows of an index matrix according to the dimension(s) of interest.
+// The dimensions to sort by are passed to the constructor as "order".
+//
+// Example: if given index matrix IX, two rows ai and bi, and order = {2,1}.
+// operator() compares
+// IX(ai,2) < IX(bi,2).
+// If IX(ai,2) == IX(bi,2), it compares
+// IX(ai,1) < IX(bi,1).
+//
+// This can be used to sort a vector of row indices into IX according to
+// the values in IX in particular columns (dimensions) of interest.
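+//
+// A minimal sketch (assuming an N x dims index matrix `ix` of type
+// TTypes<int64>::Matrix):
+//
+//   std::vector<int64> order{0, 1, 2};
+//   std::vector<int64> rows(N);
+//   std::iota(rows.begin(), rows.end(), 0);
+//   std::sort(rows.begin(), rows.end(), DimComparator(ix, order, dims));
+//   // rows now holds the row numbers of ix sorted lexicographically by
+//   // dimensions 0, then 1, then 2.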
+class DimComparator {
+ public:
+ typedef typename gtl::ArraySlice<int64> VarDimArray;
+
+ inline DimComparator(const TTypes<int64>::Matrix& ix,
+ const VarDimArray& order, int dims)
+ : ix_(ix), order_(order), dims_(dims) {
+ CHECK_GT(order.size(), 0) << "Must order using at least one index";
+ CHECK_LE(order.size(), dims_) << "Can only sort up to dims";
+ for (size_t d = 0; d < order.size(); ++d) {
+ CHECK_GE(order[d], 0);
+ CHECK_LT(order[d], dims);
+ }
+ }
+
+ inline bool operator()(const int64 i, const int64 j) const {
+ for (int di = 0; di < dims_; ++di) {
+ const int64 d = order_[di];
+ if (ix_(i, d) < ix_(j, d)) return true;
+ if (ix_(i, d) > ix_(j, d)) return false;
+ }
+ return false;
+ }
+
+ const TTypes<int64>::Matrix ix_;
+ const VarDimArray order_;
+ const int dims_;
+};
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_DIM_COMPARATOR_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
new file mode 100644
index 0000000000..e153bcdbb4
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -0,0 +1,49 @@
+#include "tensorflow/core/util/sparse/group_iterator.h"
+
+namespace tensorflow {
+namespace sparse {
+
+void GroupIterable::IteratorStep::UpdateEndOfGroup() {
+ ++next_loc_;
+ int64 N = iter_->ix_.dim_size(0);
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
+ ++next_loc_;
+ }
+}
+
+bool GroupIterable::IteratorStep::operator!=(const IteratorStep& rhs) const {
+ CHECK_EQ(rhs.iter_, iter_) << "Can't compare steps from different iterators";
+ return (rhs.loc_ != loc_);
+}
+
+GroupIterable::IteratorStep& GroupIterable::IteratorStep::
+operator++() { // prefix ++
+ loc_ = next_loc_;
+ UpdateEndOfGroup();
+ return *this;
+}
+
+GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
+ int) { // postfix ++
+ IteratorStep lhs(*this);
+ ++(*this);
+ return lhs;
+}
+
+std::vector<int64> Group::group() const {
+ std::vector<int64> g;
+ auto ix_t = iter_->ix_.template matrix<int64>();
+ for (const int d : iter_->group_dims_) {
+ g.push_back(ix_t(loc_, d));
+ }
+ return g;
+}
+
+TTypes<int64>::UnalignedConstMatrix Group::indices() const {
+ return TTypes<int64>::UnalignedConstMatrix(
+ &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+}
+
+} // namespace sparse
+} // namespace tensorflow
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
new file mode 100644
index 0000000000..8423d54f27
--- /dev/null
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -0,0 +1,120 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+#define TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+class GroupIterable; // Predeclare GroupIterable for Group.
+
+// This class is returned when dereferencing a GroupIterable iterator.
+// It provides the methods group(), indices(), and values(), which
+// provide access into the underlying SparseTensor.
+class Group {
+ public:
+ Group(GroupIterable* iter, int64 loc, int64 next_loc)
+ : iter_(iter), loc_(loc), next_loc_(next_loc) {}
+
+ std::vector<int64> group() const;
+ TTypes<int64>::UnalignedConstMatrix indices() const;
+ template <typename T>
+ typename TTypes<T>::UnalignedVec values() const;
+
+ private:
+ GroupIterable* iter_;
+ int64 loc_;
+ int64 next_loc_;
+};
+
+/////////////////
+// GroupIterable
+/////////////////
+//
+// Returned when calling sparse_tensor.group({dim0, dim1, ...}).
+//
+// Please note: the sparse_tensor should already be ordered according
+// to {dim0, dim1, ...}. Otherwise this iteration will return invalid groups.
+//
+// Allows grouping and iteration of the SparseTensor according to the
+// subset of dimensions provided to the group call.
+//
+// The actual grouping dimensions are stored in the
+// internal vector group_dims_. Iterators inside the iterable provide
+// the three methods:
+//
+// * group(): returns a vector with the current group dimension values.
+// * indices(): an unaligned matrix map over the indices in
+//   this group.
+// * values(): an unaligned vector map over the values in
+//   this group.
+//
+// To iterate across GroupIterable, see examples in README.md.
+//
+
+class GroupIterable {
+ public:
+ typedef gtl::ArraySlice<int64> VarDimArray;
+
+ GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
+ : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
+
+ class IteratorStep;
+
+ IteratorStep begin() { return IteratorStep(this, 0); }
+ IteratorStep end() { return IteratorStep(this, ix_.dim_size(0)); }
+
+ template <typename TIX>
+ inline bool GroupMatches(const TIX& ix, int64 loc_a, int64 loc_b) const {
+ bool matches = true;
+ for (int d : group_dims_) {
+ if (ix(loc_a, d) != ix(loc_b, d)) {
+ matches = false;
+ }
+ }
+ return matches;
+ }
+
+ class IteratorStep {
+ public:
+ IteratorStep(GroupIterable* iter, int64 loc)
+ : iter_(iter), loc_(loc), next_loc_(loc_) {
+ UpdateEndOfGroup();
+ }
+
+ void UpdateEndOfGroup();
+ bool operator!=(const IteratorStep& rhs) const;
+ IteratorStep& operator++(); // prefix ++
+ IteratorStep operator++(int); // postfix ++
+ Group operator*() const { return Group(iter_, loc_, next_loc_); }
+
+ private:
+ GroupIterable* iter_;
+ int64 loc_;
+ int64 next_loc_;
+ };
+
+ private:
+ friend class Group;
+ Tensor ix_;
+ Tensor vals_;
+ const int dims_;
+ const VarDimArray group_dims_;
+};
+
+// Implementation of Group::values<T>()
+template <typename T>
+typename TTypes<T>::UnalignedVec Group::values() const {
+ return typename TTypes<T>::UnalignedVec(&(iter_->vals_.vec<T>()(loc_)),
+ next_loc_ - loc_);
+}
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_GROUP_ITERATOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
new file mode 100644
index 0000000000..dcb75e7f54
--- /dev/null
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -0,0 +1,353 @@
+#ifndef TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+#define TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
+
+#include <limits>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/util/sparse/dim_comparator.h"
+#include "tensorflow/core/util/sparse/group_iterator.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+
+class SparseTensor {
+ public:
+ typedef typename gtl::ArraySlice<int64> VarDimArray;
+
+ SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
+ : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
+
+ SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
+ const VarDimArray& order)
+ : ix_(ix),
+ vals_(vals),
+ shape_(shape),
+ order_(order.begin(), order.end()),
+ dims_(GetDimsFromIx(ix)) {
+ CHECK_EQ(ix.dtype(), DT_INT64) << "indices must be type int64 but got: "
+ << ix.dtype();
+ CHECK(TensorShapeUtils::IsMatrix(ix.shape()))
+ << "indices must be a matrix, but got: " << ix.shape().DebugString();
+ CHECK(TensorShapeUtils::IsVector(vals.shape()))
+ << "vals must be a vec, but got: " << vals.shape().DebugString();
+ CHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
+ << "indices and values rows (indexing dimension) must match.";
+ }
+
+ std::size_t num_entries() const { return ix_.dim_size(0); }
+
+ const Tensor& indices() const { return ix_; }
+
+ const Tensor& values() const { return vals_; }
+
+ DataType dtype() const { return vals_.dtype(); }
+
+ bool IndicesValid() const {
+ const auto ix_t = ix_.matrix<int64>();
+ for (int64 ord : order_) {
+      CHECK_GE(ord, 0) << "Order was not provided. Provide an order at "
+                          "construction time or run Reorder<T>()";
+ }
+
+ for (std::size_t n = 0; n < num_entries(); ++n) {
+ if (!IndexValid(ix_t, n)) return false;
+ }
+
+ return true;
+ }
+
+ // Returns the tensor shape (the dimensions of the "densified"
+ // tensor this tensor represents).
+ const TensorShape shape() const { return shape_; }
+
+ const VarDimArray order() const { return order_; }
+
+ // Resorts the indices and values according to the dimensions in order.
+ template <typename T>
+ void Reorder(const VarDimArray& order);
+
+ // Returns a group iterable that can be used for clumping indices
+ // and values according to the group indices of interest.
+ //
+ // Precondition: order()[0..group_ix.size()] == group_ix.
+ //
+ // See the README.md in this directory for more usage information.
+ GroupIterable group(const VarDimArray& group_ix) {
+ CHECK_LE(group_ix.size(), dims_);
+ for (std::size_t di = 0; di < group_ix.size(); ++di) {
+ CHECK_GE(group_ix[di], 0) << "Group dimension out of range";
+ CHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
+ CHECK_EQ(group_ix[di], order_[di])
+ << "Group dimension does not match sorted order";
+ }
+ return GroupIterable(ix_, vals_, dims_, group_ix);
+ }
+
+ // Stores the sparse indices into the dense tensor out.
+ // Preconditions:
+ // out->shape().dims() == shape().dims()
+ // out->shape().dim_size(d) >= shape(d) for all d
+ //
+ // Returns true on success. False on failure (mismatched dimensions
+ // or out-of-bounds indices).
+ //
+  // If initialize is true, ToDense first sets all coefficients in out to T().
+ //
+ template <typename T>
+ bool ToDense(Tensor* out, bool initialize = true);
+
+ // Concat() will concatenate all the tensors according to their first order
+ // dimension. All tensors must have identical shape except for
+ // the first order dimension. All tensors orders' first dimension
+ // must match.
+ //
+ // If all of the tensors have identical ordering, then the output
+ // will have this ordering. Otherwise the output is set as not
+ // having any order and a Reorder<T>() should be called on it before
+ // performing any subsequent operations.
+ template <typename T>
+ static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
+
+ private:
+ static int GetDimsFromIx(const Tensor& ix) {
+ CHECK(TensorShapeUtils::IsMatrix(ix.shape()));
+ return ix.dim_size(1);
+ }
+
+ static gtl::InlinedVector<int64, 8> UndefinedOrder(const TensorShape& shape) {
+ return gtl::InlinedVector<int64, 8>(shape.dims(), -1);
+ }
+
+ // Helper for IndicesValid()
+ inline bool IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
+ int64 n) const {
+ bool different = false;
+ bool bad_order = false;
+ bool valid = true;
+ if (n == 0) {
+ for (int di = 0; di < dims_; ++di) {
+ if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di))
+ valid = false;
+ }
+ different = true;
+ } else {
+ for (int di = 0; di < dims_; ++di) {
+ if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_.dim_size(di))
+ valid = false;
+ int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
+ if (diff > 0) different = true;
+ if (!different && diff < 0) bad_order = true;
+ }
+ }
+ if (!valid) return false; // Out of bounds
+    if (!different) return false;  // Current and previous indices are identical.
+ if (bad_order) return false; // Decreasing in order.
+ return true;
+ }
+
+ // Helper for ToDense<T>()
+ template <typename T>
+ bool ValidateAndInitializeToDense(Tensor* out, bool initialize);
+
+ Tensor ix_;
+ Tensor vals_;
+ TensorShape shape_;
+ gtl::InlinedVector<int64, 8> order_;
+ const int dims_;
+};
+
+// This operation updates the indices and values Tensor rows, so it is
+// an in-place algorithm. It requires O(N log N) time and O(N)
+// temporary space.
+template <typename T>
+void SparseTensor::Reorder(const VarDimArray& order) {
+ CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ << "Reorder requested with the wrong datatype";
+ CHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
+ auto ix_t = ix_.matrix<int64>();
+ auto vals_t = vals_.vec<T>();
+
+ DimComparator sorter(ix_t, order, dims_);
+
+ std::vector<int64> reorder(num_entries());
+ std::iota(reorder.begin(), reorder.end(), 0);
+
+ // Sort to get order of indices
+ std::sort(reorder.begin(), reorder.end(), sorter);
+
+  // We have a forward reordering, but what we'll need is a
+  // permutation (the inverse). This can be calculated with O(1) additional
+  // space and O(n) time (INVPERM), but we just do the simple thing here.
+ std::vector<int64> permutation(reorder.size());
+ for (std::size_t n = 0; n < reorder.size(); ++n) {
+ permutation[reorder[n]] = n;
+ }
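+  // For example, if the sort yields reorder = {2, 0, 1} (row 2 sorts first),
+  // then permutation = {1, 2, 0}: row 0 must move to slot 1, row 1 to slot 2,
+  // and row 2 to slot 0.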
+
+ // Update indices & values by converting the permutations to
+ // a product of transpositions. Iterate over the cycles in the
+ // permutation, and convert each of those into a product of
+ // transpositions (swaps):
+ // https://en.wikipedia.org/wiki/Cyclic_permutation
+ // This is N swaps, 2*N comparisons.
+ for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
+ while (n != permutation[n]) {
+ std::size_t r = permutation[n];
+ std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
+ std::swap(vals_t(n), vals_t(r));
+ std::swap(permutation[n], permutation[r]);
+ }
+ }
+
+ order_ = gtl::InlinedVector<int64, 8>(order.begin(), order.end());
+}
+
+template <typename T>
+bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
+ CHECK_EQ(DataTypeToEnum<T>::v(), dtype())
+ << "ToDense requested with the wrong datatype";
+
+ CHECK_EQ(out->shape().dims(), dims_)
+ << "Incompatible dimensions between SparseTensor and output";
+
+ CHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
+ << "Output must be type: " << DataTypeToEnum<T>::v()
+ << " but got: " << out->dtype();
+
+ // Make sure the dense output is the same rank and has room
+ // to hold the SparseTensor.
+ const auto& out_shape = out->shape();
+ if (shape_.dims() != out_shape.dims()) return false;
+ for (int d = 0; d < shape_.dims(); ++d) {
+ if (shape_.dim_size(d) > out_shape.dim_size(d)) return false;
+ }
+
+ if (initialize) {
+ auto out_t = out->flat<T>();
+ out_t.setConstant(T());
+ }
+
+ return true;
+}
+
+template <typename T>
+bool SparseTensor::ToDense(Tensor* out, bool initialize) {
+ if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;
+
+ auto out_t = out->flat<T>();
+ auto ix_t = ix_.matrix<int64>();
+ auto vals_t = vals_.vec<T>();
+
+ std::vector<int64> strides(dims_);
+ const auto& out_shape = out->shape();
+ strides[dims_ - 1] = 1;
+ for (int d = dims_ - 2; d >= 0; --d) {
+ strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
+ }
+
+ for (std::size_t n = 0; n < vals_t.dimension(0); ++n) {
+ bool invalid_dims = false;
+ int64 ix = 0;
+ for (int d = 0; d < dims_; ++d) {
+ const int64 ix_n_d = ix_t(n, d);
+ if (ix_n_d < 0 || ix_n_d >= out_shape.dim_size(d)) {
+ invalid_dims = true;
+ }
+ ix += strides[d] * ix_n_d;
+ }
+ if (invalid_dims) return false;
+ out_t(ix) = vals_t(n);
+ }
+ return true;
+}
+
+template <typename T>
+SparseTensor SparseTensor::Concat(
+ const gtl::ArraySlice<SparseTensor>& tensors) {
+ CHECK_GE(tensors.size(), 1) << "Cannot concat 0 SparseTensors";
+ const int dims = tensors[0].dims_;
+ CHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
+ auto order_0 = tensors[0].order();
+ const int primary_dim = order_0[0];
+ gtl::InlinedVector<int64, 8> final_order(order_0.begin(), order_0.end());
+ TensorShape final_shape(tensors[0].shape());
+ final_shape.set_dim(primary_dim, 0); // We'll build this up as we go along.
+ int num_entries = 0;
+
+ bool fully_ordered = true;
+ for (const SparseTensor& st : tensors) {
+ CHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
+ CHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
+ << "Concat requested with the wrong data type";
+ CHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
+ CHECK_EQ(st.order()[0], primary_dim)
+ << "All SparseTensors' order[0] must match. This is the concat dim.";
+ if (st.order() != final_order) fully_ordered = false;
+ const TensorShape st_shape = st.shape();
+ for (int d = 0; d < dims - 1; ++d) {
+ const int cdim = (d < primary_dim) ? d : d + 1;
+ CHECK_EQ(final_shape.dim_size(cdim), st_shape.dim_size(cdim))
+ << "All SparseTensors' shapes must match except on the concat dim. "
+ << "Concat dim: " << primary_dim
+ << ", mismatched shape at dim: " << cdim
+ << ". Expecting shape like: " << final_shape.DebugString()
+ << " but saw shape: " << st_shape.DebugString();
+ }
+
+ // Update dimension of final shape
+ final_shape.set_dim(primary_dim, final_shape.dim_size(primary_dim) +
+ st_shape.dim_size(primary_dim));
+
+ num_entries += st.num_entries(); // Update number of entries
+ }
+
+ // If nonconsistent ordering among inputs, set final order to -1s.
+ if (!fully_ordered) {
+ final_order = UndefinedOrder(final_shape);
+ }
+
+ Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
+ Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));
+
+ auto ix_t = output_ix.matrix<int64>();
+ auto vals_t = output_vals.vec<T>();
+
+ Eigen::DenseIndex offset = 0;
+ int64 shape_offset = 0;
+ for (const SparseTensor& st : tensors) {
+ int st_num_entries = st.num_entries();
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_start(offset, 0);
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_size(st_num_entries, dims);
+ Eigen::DSizes<Eigen::DenseIndex, 1> vals_start(offset);
+ Eigen::DSizes<Eigen::DenseIndex, 1> vals_size(st_num_entries);
+
+ // Fill in indices & values.
+ ix_t.slice(ix_start, ix_size) = st.ix_.matrix<int64>();
+ vals_t.slice(vals_start, vals_size) = st.vals_.vec<T>();
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_start(offset, primary_dim);
+ Eigen::DSizes<Eigen::DenseIndex, 2> ix_update_size(st_num_entries, 1);
+ // The index associated with the primary dimension gets increased
+ // by the shapes of the previous concatted Tensors.
+ auto update_slice = ix_t.slice(ix_update_start, ix_update_size);
+ update_slice += update_slice.constant(shape_offset);
+
+ offset += st_num_entries;
+ shape_offset += st.shape().dim_size(primary_dim);
+ }
+
+ return SparseTensor(output_ix, output_vals, final_shape, final_order);
+}
+
+} // namespace sparse
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SPARSE_SPARSE_TENSOR_H_
diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc
new file mode 100644
index 0000000000..47126b7187
--- /dev/null
+++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc
@@ -0,0 +1,467 @@
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/public/tensor.h"
+#include <gtest/gtest.h>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace sparse {
+namespace {
+
+Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex>
+GetSimpleIndexTensor(int N, const int NDIM) {
+ Eigen::Tensor<int64, 2, Eigen::RowMajor, Eigen::DenseIndex> ix(N, NDIM);
+ ix(0, 0) = 0;
+ ix(0, 1) = 0;
+ ix(0, 2) = 0;
+
+ ix(1, 0) = 3;
+ ix(1, 1) = 0;
+ ix(1, 2) = 0;
+
+ ix(2, 0) = 2;
+ ix(2, 1) = 0;
+ ix(2, 2) = 0;
+
+ ix(3, 0) = 0;
+ ix(3, 1) = 1;
+ ix(3, 2) = 0;
+
+ ix(4, 0) = 0;
+ ix(4, 1) = 0;
+ ix(4, 2) = 2;
+ return ix;
+}
+
+TEST(SparseTensorTest, DimComparatorSorts) {
+ std::size_t N = 5;
+ const int NDIM = 3;
+ auto ix = GetSimpleIndexTensor(N, NDIM);
+ TTypes<int64>::Matrix map(ix.data(), N, NDIM);
+
+ std::vector<int64> sorting(N);
+ for (std::size_t n = 0; n < N; ++n) sorting[n] = n;
+
+ // new order should be: {0, 4, 3, 2, 1}
+ std::vector<int64> order{0, 1, 2};
+ DimComparator sorter(map, order, NDIM);
+ std::sort(sorting.begin(), sorting.end(), sorter);
+
+ EXPECT_EQ(sorting, std::vector<int64>({0, 4, 3, 2, 1}));
+
+ // new order should be: {0, 3, 2, 1, 4}
+ std::vector<int64> order1{2, 0, 1};
+ DimComparator sorter1(map, order1, NDIM);
+ for (std::size_t n = 0; n < N; ++n) sorting[n] = n;
+ std::sort(sorting.begin(), sorting.end(), sorter1);
+
+ EXPECT_EQ(sorting, std::vector<int64>({0, 3, 2, 1, 4}));
+}
+
+TEST(SparseTensorTest, SparseTensorConstruction) {
+ int N = 5;
+ const int NDIM = 3;
+ auto ix_c = GetSimpleIndexTensor(N, NDIM);
+ Eigen::Tensor<string, 1, Eigen::RowMajor> vals_c(N);
+ vals_c(0) = "hi0";
+ vals_c(1) = "hi1";
+ vals_c(2) = "hi2";
+ vals_c(3) = "hi3";
+ vals_c(4) = "hi4";
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<string>();
+ vals_t = vals_c;
+ ix_t = ix_c;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid()); // Out of order
+
+ // Regardless of how order is updated; so long as there are no
+ // duplicates, the resulting indices are valid.
+ st.Reorder<string>({2, 0, 1});
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(vals_t(0), "hi0");
+ EXPECT_EQ(vals_t(1), "hi3");
+ EXPECT_EQ(vals_t(2), "hi2");
+ EXPECT_EQ(vals_t(3), "hi1");
+ EXPECT_EQ(vals_t(4), "hi4");
+
+ ix_t = ix_c;
+ vals_t = vals_c;
+ st.Reorder<string>({0, 1, 2});
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(vals_t(0), "hi0");
+ EXPECT_EQ(vals_t(1), "hi4");
+ EXPECT_EQ(vals_t(2), "hi3");
+ EXPECT_EQ(vals_t(3), "hi2");
+ EXPECT_EQ(vals_t(4), "hi1");
+
+ ix_t = ix_c;
+ vals_t = vals_c;
+ st.Reorder<string>({2, 1, 0});
+ EXPECT_TRUE(st.IndicesValid());
+}
+
+TEST(SparseTensorTest, EmptySparseTensorAllowed) {
+ int N = 0;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(st.order(), order);
+
+ std::vector<int64> new_order{1, 0, 2};
+ st.Reorder<string>(new_order);
+ EXPECT_TRUE(st.IndicesValid());
+ EXPECT_EQ(st.order(), new_order);
+}
+
+TEST(SparseTensorTest, SortingWorksCorrectly) {
+ int N = 30;
+ const int NDIM = 4;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+ TensorShape shape({1000, 1000, 1000, 1000});
+ SparseTensor st(ix, vals, shape);
+
+ auto ix_t = ix.matrix<int64>();
+
+ for (int n = 0; n < 100; ++n) {
+ ix_t = ix_t.random(Eigen::internal::UniformRandomGenerator<int64>(n + 1));
+ ix_t = ix_t.abs() % 1000;
+ st.Reorder<string>({0, 1, 2, 3});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({3, 2, 1, 0});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({1, 0, 2, 3});
+ EXPECT_TRUE(st.IndicesValid());
+ st.Reorder<string>({3, 0, 2, 1});
+ EXPECT_TRUE(st.IndicesValid());
+ }
+}
+
+TEST(SparseTensorTest, ValidateIndicesFindsInvalid) {
+ int N = 2;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ Eigen::Tensor<int64, 2, Eigen::RowMajor> ix_orig(N, NDIM);
+ ix_orig(0, 0) = 0;
+ ix_orig(0, 1) = 0;
+ ix_orig(0, 2) = 0;
+
+ ix_orig(1, 0) = 0;
+ ix_orig(1, 1) = 0;
+ ix_orig(1, 2) = 0;
+
+ auto ix_t = ix.matrix<int64>();
+ ix_t = ix_orig;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid()); // two indices are identical
+
+ ix_orig(1, 2) = 1;
+ ix_t = ix_orig;
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid()); // second index now (0, 0, 1)
+
+ ix_orig(0, 2) = 1;
+ ix_t = ix_orig;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid()); // first index now (0, 0, 1)
+}
+
+TEST(SparseTensorTest, SparseTensorCheckBoundaries) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+
+ ix.matrix<int64>() = ix_t;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+
+ ix_t(0, 0) = 11;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ ix_t(0, 0) = -1;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_FALSE(st.IndicesValid());
+
+ ix_t(0, 0) = 0;
+ ix.matrix<int64>() = ix_t;
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+}
+
+TEST(SparseTensorTest, SparseTensorToDenseTensor) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+ auto vals_t = vals.vec<string>();
+
+ ix.matrix<int64>() = ix_t;
+
+ vals_t(0) = "hi0";
+ vals_t(1) = "hi1";
+ vals_t(2) = "hi2";
+ vals_t(3) = "hi3";
+ vals_t(4) = "hi4";
+
+ TensorShape shape({4, 4, 5});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ Tensor dense(DT_STRING, TensorShape({4, 4, 5}));
+ st.ToDense<string>(&dense);
+
+ auto dense_t = dense.tensor<string, 3>();
+ Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
+ EXPECT_EQ(dense_t(ix_n), vals_t(n));
+ }
+
+ // Spot checks on the others
+ EXPECT_EQ(dense_t(0, 0, 1), "");
+ EXPECT_EQ(dense_t(0, 0, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 4), "");
+}
+
+TEST(SparseTensorTest, SparseTensorToLargerDenseTensor) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_t = GetSimpleIndexTensor(N, NDIM);
+ auto vals_t = vals.vec<string>();
+
+ ix.matrix<int64>() = ix_t;
+
+ vals_t(0) = "hi0";
+ vals_t(1) = "hi1";
+ vals_t(2) = "hi2";
+ vals_t(3) = "hi3";
+ vals_t(4) = "hi4";
+
+ TensorShape shape({4, 4, 5});
+ std::vector<int64> order{0, 1, 2};
+ SparseTensor st(ix, vals, shape, order);
+
+ Tensor dense(DT_STRING, TensorShape({10, 10, 10}));
+ st.ToDense<string>(&dense);
+
+ auto dense_t = dense.tensor<string, 3>();
+ Eigen::array<Eigen::DenseIndex, NDIM> ix_n;
+ for (int n = 0; n < N; ++n) {
+ for (int d = 0; d < NDIM; ++d) ix_n[d] = ix_t(n, d);
+ EXPECT_EQ(dense_t(ix_n), vals_t(n));
+ }
+
+ // Spot checks on the others
+ EXPECT_EQ(dense_t(0, 0, 1), "");
+ EXPECT_EQ(dense_t(0, 0, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 3), "");
+ EXPECT_EQ(dense_t(3, 3, 4), "");
+ EXPECT_EQ(dense_t(9, 0, 0), "");
+ EXPECT_EQ(dense_t(9, 0, 9), "");
+ EXPECT_EQ(dense_t(9, 9, 9), "");
+}
+
+TEST(SparseTensorTest, SparseTensorGroup) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_INT32, TensorShape({N}));
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<int32>();
+
+ ix_t = GetSimpleIndexTensor(N, NDIM);
+
+ vals_t(0) = 1; // associated with ix (000)
+ vals_t(1) = 2; // associated with ix (300)
+ vals_t(2) = 3; // associated with ix (200)
+ vals_t(3) = 4; // associated with ix (010)
+ vals_t(4) = 5; // associated with ix (002)
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ st.Reorder<int32>(order);
+
+ std::vector<std::vector<int64> > groups;
+ std::vector<TTypes<int64>::UnalignedConstMatrix> grouped_indices;
+ std::vector<TTypes<int32>::UnalignedVec> grouped_values;
+
+ // Group by index 0
+ auto gi = st.group({0});
+
+ // All the hard work is right here!
+ for (const auto& g : gi) {
+ groups.push_back(g.group());
+ VLOG(1) << "Group: " << str_util::Join(g.group(), ",");
+ VLOG(1) << "Indices: " << g.indices();
+ VLOG(1) << "Values: " << g.values<int32>();
+
+ grouped_indices.push_back(g.indices());
+ grouped_values.push_back(g.values<int32>());
+ }
+
+ // Group by dimension 0, we have groups: 0--, 2--, 3--
+ EXPECT_EQ(groups.size(), 3);
+ EXPECT_EQ(groups[0], std::vector<int64>({0}));
+ EXPECT_EQ(groups[1], std::vector<int64>({2}));
+ EXPECT_EQ(groups[2], std::vector<int64>({3}));
+
+ std::vector<Eigen::Tensor<int64, 2, Eigen::RowMajor> > expected_indices;
+ std::vector<Eigen::Tensor<int32, 1, Eigen::RowMajor> > expected_vals;
+
+ // First group: 000, 002, 010
+ expected_indices.emplace_back(3, NDIM); // 3 x 3 tensor
+  expected_vals.emplace_back(3);           // vector of length 3
+ expected_indices[0].setZero();
+ expected_indices[0](1, 2) = 2; // 002
+ expected_indices[0](2, 1) = 1; // 010
+ expected_vals[0].setConstant(-1);
+ expected_vals[0](0) = 1; // val associated with ix 000
+ expected_vals[0](1) = 5; // val associated with ix 002
+ expected_vals[0](2) = 4; // val associated with ix 010
+
+ // Second group: 200
+ expected_indices.emplace_back(1, NDIM);
+ expected_vals.emplace_back(1);
+ expected_indices[1].setZero();
+ expected_indices[1](0, 0) = 2; // 200
+ expected_vals[1](0) = 3; // val associated with ix 200
+
+ // Third group: 300
+ expected_indices.emplace_back(1, NDIM);
+ expected_vals.emplace_back(1);
+ expected_indices[2].setZero();
+ expected_indices[2](0, 0) = 3; // 300
+ expected_vals[2](0) = 2; // val associated with ix 300
+
+ for (std::size_t gix = 0; gix < groups.size(); ++gix) {
+ // Compare indices
+ auto gi_t = grouped_indices[gix];
+ Eigen::Tensor<bool, 0, Eigen::RowMajor> eval =
+ (gi_t == expected_indices[gix]).all();
+ EXPECT_TRUE(eval()) << gix << " indices: " << gi_t << " vs. "
+ << expected_indices[gix];
+
+ // Compare values
+ auto gv_t = grouped_values[gix];
+ eval = (gv_t == expected_vals[gix]).all();
+ EXPECT_TRUE(eval()) << gix << " values: " << gv_t << " vs. "
+ << expected_vals[gix];
+ }
+}
+
+TEST(SparseTensorTest, Concat) {
+ int N = 5;
+ const int NDIM = 3;
+
+ Tensor ix(DT_INT64, TensorShape({N, NDIM}));
+ Tensor vals(DT_STRING, TensorShape({N}));
+
+ auto ix_c = GetSimpleIndexTensor(N, NDIM);
+
+ auto ix_t = ix.matrix<int64>();
+ auto vals_t = vals.vec<string>();
+
+ ix_t = ix_c;
+
+ TensorShape shape({10, 10, 10});
+ std::vector<int64> order{0, 1, 2};
+
+ SparseTensor st(ix, vals, shape, order);
+ EXPECT_FALSE(st.IndicesValid());
+ st.Reorder<string>(order);
+ EXPECT_TRUE(st.IndicesValid());
+
+ SparseTensor concatted = SparseTensor::Concat<string>({st, st, st, st});
+ EXPECT_EQ(concatted.order(), st.order());
+ TensorShape expected_shape({40, 10, 10});
+ EXPECT_EQ(concatted.shape(), expected_shape);
+ EXPECT_EQ(concatted.num_entries(), 4 * N);
+ EXPECT_TRUE(concatted.IndicesValid());
+
+ auto conc_ix_t = concatted.indices().matrix<int64>();
+ auto conc_vals_t = concatted.values().vec<string>();
+
+ for (int n = 0; n < 4; ++n) {
+ for (int i = 0; i < N; ++i) {
+ // Dimensions match except the primary dim, which is offset by
+ // shape[order[0]]
+ EXPECT_EQ(conc_ix_t(n * N + i, 0), 10 * n + ix_t(i, 0));
+ EXPECT_EQ(conc_ix_t(n * N + i, 1), ix_t(i, 1));
+      EXPECT_EQ(conc_ix_t(n * N + i, 2), ix_t(i, 2));
+
+ // Values match
+ EXPECT_EQ(conc_vals_t(n * N + i), vals_t(i));
+ }
+ }
+
+ // Concat works if non-primary ix is out of order, but output order
+ // is not defined
+ SparseTensor st_ooo(ix, vals, shape, {0, 2, 1}); // non-primary ix OOO
+ SparseTensor conc_ooo = SparseTensor::Concat<string>({st, st, st, st_ooo});
+ std::vector<int64> expected_ooo{-1, -1, -1};
+ EXPECT_EQ(conc_ooo.order(), expected_ooo);
+ EXPECT_EQ(conc_ooo.shape(), expected_shape);
+ EXPECT_EQ(conc_ooo.num_entries(), 4 * N);
+}
+
+// TODO(ebrevdo): ReduceToDense(R={dim1,dim2,...}, reduce_fn, &output)
+// reduce_fn sees slices of resorted values based on generator (dim: DDIMS), and
+// slices of resorted indices on generator.
+
+} // namespace
+} // namespace sparse
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc
new file mode 100644
index 0000000000..00bc16f105
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader.cc
@@ -0,0 +1,230 @@
+#include "tensorflow/core/util/tensor_slice_reader.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/io/iterator.h"
+#include "tensorflow/core/lib/io/match.h"
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+TensorSliceReader::Table::~Table() {}
+
+namespace {
+class TensorSliceReaderTable : public TensorSliceReader::Table {
+ public:
+ explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
+ : file_(f), table_(t) {}
+
+ ~TensorSliceReaderTable() override {
+ delete table_;
+ delete file_;
+ }
+
+ bool Get(const string& key, string* value) override {
+ std::unique_ptr<table::Iterator> iter(table_->NewIterator());
+ iter->Seek(key);
+ if (iter->Valid() && iter->key() == key) {
+ StringPiece v = iter->value();
+ value->assign(v.data(), v.size());
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ private:
+ RandomAccessFile* file_;
+ table::Table* table_;
+};
+} // namespace
+
+Status OpenTableTensorSliceReader(const string& fname,
+ TensorSliceReader::Table** result) {
+ *result = nullptr;
+ Env* env = Env::Default();
+ RandomAccessFile* f = nullptr;
+ Status s = env->NewRandomAccessFile(fname, &f);
+ if (s.ok()) {
+ uint64 file_size;
+ s = env->GetFileSize(fname, &file_size);
+ if (s.ok()) {
+ table::Options options;
+ table::Table* table;
+ s = table::Table::Open(options, f, file_size, &table);
+ if (s.ok()) {
+ *result = new TensorSliceReaderTable(f, table);
+ return Status::OK();
+ } else {
+ s = Status(s.code(),
+ strings::StrCat(s.error_message(),
+ ": perhaps your file is in a different "
+ "file format and you need to use a "
+ "different restore operator?"));
+ }
+ }
+ }
+ LOG(WARNING) << "Could not open " << fname << ": " << s;
+ delete f;
+ return s;
+}
+
+TensorSliceReader::TensorSliceReader(const string& filepattern,
+ OpenTableFunction open_function)
+ : TensorSliceReader(filepattern, open_function, kLoadAllShards) {}
+
+TensorSliceReader::TensorSliceReader(const string& filepattern,
+ OpenTableFunction open_function,
+ int preferred_shard)
+ : filepattern_(filepattern), open_function_(open_function) {
+ VLOG(1) << "TensorSliceReader for " << filepattern;
+ Status s = io::GetMatchingFiles(Env::Default(), filepattern, &fnames_);
+ if (!s.ok()) {
+ status_ = errors::InvalidArgument(
+ "Unsuccessful TensorSliceReader constructor: "
+ "Failed to get matching files on ",
+ filepattern, ": ", s.ToString());
+ return;
+ }
+ if (fnames_.empty()) {
+ status_ = errors::NotFound(
+ "Unsuccessful TensorSliceReader constructor: "
+ "Failed to find any matching files for ",
+ filepattern);
+ return;
+ }
+ sss_.resize(fnames_.size());
+ for (size_t shard = 0; shard < fnames_.size(); ++shard) {
+ fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
+ }
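+  // Either load every shard eagerly, or load just the preferred shard now and
+  // fall back to LoadAllShards() lazily from HasTensor()/CopySliceData() if a
+  // lookup misses.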
+ if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
+ static_cast<size_t>(preferred_shard) >= fnames_.size()) {
+ LoadAllShards();
+ } else {
+ VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
+ LoadShard(preferred_shard);
+ }
+}
+
+void TensorSliceReader::LoadShard(int shard) const {
+ CHECK_LT(shard, sss_.size());
+ if (sss_[shard] || !status_.ok()) {
+ return; // Already loaded, or invalid.
+ }
+ string value;
+ SavedTensorSlices sts;
+ const string fname = fnames_[shard];
+ VLOG(1) << "Reading meta data from file " << fname << "...";
+ Table* table;
+ Status s = open_function_(fname, &table);
+ if (!s.ok()) {
+ status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
+ s.ToString());
+ return;
+ }
+ sss_[shard].reset(table);
+ if (!(table->Get(kSavedTensorSlicesKey, &value) &&
+ ParseProtoUnlimited(&sts, value))) {
+ status_ = errors::Internal(
+ "Failed to find the saved tensor slices at the beginning of the "
+ "checkpoint file: ",
+ fname);
+ return;
+ }
+ for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
+ TensorShape ssm_shape(ssm.shape());
+ for (const TensorSliceProto& tsp : ssm.slice()) {
+ TensorSlice ss_slice(tsp);
+ RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname, ss_slice);
+ }
+ }
+}
+
+void TensorSliceReader::LoadAllShards() const {
+ VLOG(1) << "Loading all shards for " << filepattern_;
+ for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
+ LoadShard(i);
+ }
+ all_shards_loaded_ = true;
+}
+
+const TensorSliceSet* TensorSliceReader::FindTensorSlice(
+ const string& name, const TensorSlice& slice,
+ std::vector<std::pair<TensorSlice, string>>* details) const {
+ const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
+ if (tss && !tss->QueryMeta(slice, details)) {
+ return nullptr;
+ }
+ return tss;
+}
+
+TensorSliceReader::~TensorSliceReader() { gtl::STLDeleteValues(&tensors_); }
+
+void TensorSliceReader::RegisterTensorSlice(const string& name,
+ const TensorShape& shape,
+ DataType type, const string& tag,
+ const TensorSlice& slice) const {
+ TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
+ // Create a tensor slice set if needed
+ if (!tss) {
+ tss = new TensorSliceSet(shape, type);
+ tensors_.insert(std::make_pair(name, tss));
+ } else {
+ // Check if the shapes match
+ TensorShape tss_shape(tss->shape());
+ if (!shape.IsSameSize(tss_shape)) {
+ status_ =
+ errors::Internal("Incompatible tensor shapes detected for tensor ",
+ name, ": existing = ", tss_shape.DebugString(),
+ ", new = ", shape.DebugString());
+ return;
+ }
+ if (type != tss->type()) {
+ status_ =
+ errors::Internal("Incompatible tensor types detected for tensor ",
+ name, ": existing = ", DataTypeString(tss->type()),
+ ", new = ", DataTypeString(type));
+ return;
+ }
+ }
+ // Register the tensor slices without the actual data.
+ Status s = tss->Register(slice, tag, nullptr);
+ if (!s.ok()) {
+ status_ = s;
+ }
+}
+
+bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
+ DataType* type) const {
+ mutex_lock l(mu_);
+ const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
+ if (!tss && !all_shards_loaded_) {
+ VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
+ << name;
+ LoadAllShards();
+ tss = gtl::FindPtrOrNull(tensors_, name);
+ }
+ if (tss) {
+ if (shape) {
+ *shape = tss->shape();
+ }
+ if (type) {
+ *type = tss->type();
+ }
+ return true;
+ } else {
+ return false;
+ }
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
new file mode 100644
index 0000000000..b5f26a689b
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -0,0 +1,157 @@
+// The utility to read checkpoints for Google Brain tensor ops and v3
+// checkpoints for dist_belief.
+//
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/util/saved_tensor_slice.pb.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_set.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+// The reader reads in all the metadata about all the tensor slices. Then it
+// will try to read the relevant data on demand to produce the data for the
+// slices needed.
+// NOTE(yangke): another way to do this is to first load a list of the tensor
+// slices needed and then just selectively read some of the metadata. That
+// might optimize the loading but makes the logic a bit more complicated. We
+// might want to revisit that.
+// TODO(yangke): consider moving to TensorProto.
+class TensorSliceReader {
+ public:
+ // Abstract interface for reading data out of a tensor slice checkpoint file
+ class Table {
+ public:
+ virtual ~Table();
+ virtual bool Get(const string& key, string* value) = 0;
+ };
+ typedef std::function<Status(const string&, Table**)> OpenTableFunction;
+
+ static const int kLoadAllShards = -1;
+ TensorSliceReader(const string& filepattern, OpenTableFunction open_function);
+ TensorSliceReader(const string& filepattern, OpenTableFunction open_function,
+ int preferred_shard);
+ virtual ~TensorSliceReader();
+
+ // Get the filename this reader is attached to.
+ const string& filepattern() const { return filepattern_; }
+
+ // Get the number of files matched.
+ int num_files() const { return sss_.size(); }
+
+ // Get the status of the reader.
+ const Status status() const { return status_; }
+
+ // Checks if the reader contains any slice of a tensor. In case the reader
+ // does contain the tensor, if "shape" is not nullptr, fill "shape" with the
+ // shape of the tensor; if "type" is not nullptr, fill "type" with the type
+ // of the tensor.
+ bool HasTensor(const string& name, TensorShape* shape, DataType* type) const;
+
+ // Checks if the reader contains all the data about a tensor slice, and if
+ // yes, copies the data of the slice to "data". The caller needs to make sure
+ // that "data" points to a buffer that holds enough data.
+ // This is a slow function since it needs to read sstables.
+ template <typename T>
+ bool CopySliceData(const string& name, const TensorSlice& slice,
+ T* data) const;
+
+ // Get the tensors.
+ const std::unordered_map<string, TensorSliceSet*>& Tensors() const {
+ return tensors_;
+ }
+
+ private:
+ friend class TensorSliceWriteTestHelper;
+
+ void LoadShard(int shard) const;
+ void LoadAllShards() const;
+ void RegisterTensorSlice(const string& name, const TensorShape& shape,
+ DataType type, const string& tag,
+ const TensorSlice& slice) const;
+
+ const TensorSliceSet* FindTensorSlice(
+ const string& name, const TensorSlice& slice,
+ std::vector<std::pair<TensorSlice, string>>* details) const;
+
+ const string filepattern_;
+ const OpenTableFunction open_function_;
+ std::vector<string> fnames_;
+ std::unordered_map<string, int> fname_to_index_;
+
+ // Guards the attributes below.
+ mutable mutex mu_;
+ mutable bool all_shards_loaded_ = false;
+ mutable std::vector<std::unique_ptr<Table>> sss_;
+ mutable std::unordered_map<string, TensorSliceSet*> tensors_;
+ mutable Status status_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader);
+};
+
+Status OpenTableTensorSliceReader(const string& fname,
+ TensorSliceReader::Table** table);
+
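+// A minimal usage sketch (the checkpoint pattern and tensor name below are
+// made up for illustration):
+//
+//   TensorSliceReader reader("/tmp/ckpt_*", OpenTableTensorSliceReader);
+//   TF_CHECK_OK(reader.status());
+//   TensorShape shape;
+//   DataType type;
+//   if (reader.HasTensor("some_tensor", &shape, &type)) {
+//     std::vector<float> buf(shape.num_elements());
+//     // Assuming "some_tensor" is a 2-d float tensor, "-:-" is its full slice.
+//     CHECK(reader.CopySliceData("some_tensor",
+//                                TensorSlice::ParseOrDie("-:-"), buf.data()));
+//   }
+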
+template <typename T>
+bool TensorSliceReader::CopySliceData(const string& name,
+ const TensorSlice& slice, T* data) const {
+ std::vector<std::pair<TensorSlice, string>> details;
+ const TensorSliceSet* tss;
+ {
+ mutex_lock l(mu_);
+ tss = FindTensorSlice(name, slice, &details);
+ if (!tss && !all_shards_loaded_) {
+ VLOG(1) << "Did not find slice in preferred shard, loading all shards."
+ << name << ": " << slice.DebugString();
+ LoadAllShards();
+ tss = FindTensorSlice(name, slice, &details);
+ }
+ if (!tss) {
+ // No such tensor
+ return false;
+ }
+ }
+ // We have the data -- copy it over.
+ string value;
+ for (const auto& x : details) {
+ const TensorSlice& slice_s = x.first;
+ const string& fname = x.second;
+ int idx = gtl::FindWithDefault(fname_to_index_, fname, -1);
+ CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
+ // We read a record in the corresponding sstable
+ const string key = EncodeTensorNameSlice(name, slice_s);
+ CHECK(sss_[idx]->Get(key, &value))
+ << "Failed to seek to the record for tensor " << name << ", slice "
+ << slice_s.DebugString() << ": computed key = " << key;
+ SavedTensorSlices sts;
+ CHECK(ParseProtoUnlimited(&sts, value))
+ << "Failed to parse the record for tensor " << name << ", slice "
+ << slice_s.DebugString() << ": computed key = " << key;
+ CopyDataFromTensorSliceToTensorSlice(
+ tss->shape(), slice_s, slice,
+ checkpoint::TensorProtoData<T>(sts.data().data()), data);
+ }
+ return true;
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc
new file mode 100644
index 0000000000..af81d0115e
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader_cache.cc
@@ -0,0 +1,94 @@
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+TensorSliceReaderCacheWrapper::TensorSliceReaderCacheWrapper() {}
+TensorSliceReaderCacheWrapper::~TensorSliceReaderCacheWrapper() {
+ if (cache_) {
+ delete cache_;
+ }
+ cache_ = nullptr;
+}
+
+const TensorSliceReader* TensorSliceReaderCacheWrapper::GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function,
+ int preferred_shard) const {
+ mutex_lock l(mu_);
+ if (!cache_) {
+ cache_ = new TensorSliceReaderCache;
+ }
+ return cache_->GetReader(filepattern, open_function, preferred_shard);
+}
+
+TensorSliceReaderCache::TensorSliceReaderCache() {}
+
+TensorSliceReaderCache::~TensorSliceReaderCache() {
+ for (auto pair : readers_) {
+ delete pair.second.second;
+ }
+}
+
+const TensorSliceReader* TensorSliceReaderCache::GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function, int preferred_shard) {
+ mutex_lock l(mu_);
+
+ // Get the function pointer from the open_function value.
+ TensorSliceReaderCache::OpenFuncType* func_ptr =
+ open_function.target<TensorSliceReaderCache::OpenFuncType>();
+ if (!func_ptr) {
+ // We could not get the pointer, no caching is possible.
+ LOG(WARNING) << "Caching disabled because the open function is a lambda.";
+ return nullptr;
+ }
+
+ // Wait if another thread is already trying to open the same files.
+ while (still_opening_.find(filepattern) != still_opening_.end()) {
+ cv_.wait(l);
+ }
+
+ TensorSliceReader* reader = nullptr;
+ if (readers_.find(filepattern) == readers_.end()) {
+ VLOG(1) << "Creating new TensorSliceReader for " << filepattern;
+ still_opening_.insert(filepattern);
+    // Release the lock temporarily, as constructing a TensorSliceReader is
+    // expensive.
+ mu_.unlock();
+ TensorSliceReader* tmp_reader(
+ new TensorSliceReader(filepattern, open_function, preferred_shard));
+ // Acquire the lock again.
+ mu_.lock();
+ if (tmp_reader->status().ok()) {
+ reader = tmp_reader;
+ readers_[filepattern] = make_pair(*func_ptr, reader);
+ } else {
+ delete tmp_reader;
+ }
+ CHECK_EQ(1, still_opening_.erase(filepattern));
+ VLOG(1) << "Cached TensorSliceReader for " << filepattern << ": " << reader;
+ } else {
+ auto cached_val = readers_[filepattern];
+ if (cached_val.first == *func_ptr) {
+ reader = cached_val.second;
+ VLOG(1) << "Using cached TensorSliceReader for " << filepattern << ": "
+ << reader;
+ } else {
+ LOG(WARNING) << "Caching disabled because the checkpoint file "
+ << "is being opened with two different open functions: "
+ << filepattern;
+ }
+ }
+
+ cv_.notify_all();
+ return reader;
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h
new file mode 100644
index 0000000000..eaeeeec83f
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader_cache.h
@@ -0,0 +1,73 @@
+// The utility to read checkpoints for Google Brain tensor ops and v3
+// checkpoints for dist_belief.
+//
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceReaderCache;
+
+// Wrapper to a lazily allocated TensorSliceReaderCache.
+class TensorSliceReaderCacheWrapper {
+ public:
+ TensorSliceReaderCacheWrapper();
+ ~TensorSliceReaderCacheWrapper();
+
+ // Same as TensorSliceReaderCache::GetReader().
+ const TensorSliceReader* GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function,
+ int preferred_shard) const;
+
+ private:
+ mutable mutex mu_;
+ mutable TensorSliceReaderCache* cache_ = nullptr;
+};
+
+// A cache of TensorSliceReaders.
+class TensorSliceReaderCache {
+ public:
+ TensorSliceReaderCache();
+ ~TensorSliceReaderCache();
+
+ // Returns the TensorSliceReader corresponding to 'filepattern' and the
+  // open_function. May return nullptr if we cannot create a new
+ // TensorSliceReader for the filepattern/open_function combination.
+ const TensorSliceReader* GetReader(
+ const string& filepattern,
+ TensorSliceReader::OpenTableFunction open_function, int preferred_shard);
+
+ private:
+ // Need to use a regular function type in the key map as std::function does
+ // not support ==.
+ typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**);
+
+ // Protects attributes below.
+ mutex mu_;
+
+ // Maps of opened readers.
+ std::unordered_map<string, std::pair<OpenFuncType, TensorSliceReader*>>
+ readers_;
+
+ // Set of keys that a previous GetReader() call is still trying to populate.
+ std::set<string> still_opening_;
+
+ // Condition variable to notify when a reader has been created.
+ condition_variable cv_;
+};
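+
+// A minimal usage sketch (the checkpoint pattern below is made up for
+// illustration, and assumes the matching checkpoint files exist): readers are
+// cached by filepattern, so the second call returns the same instance.
+//
+//   TensorSliceReaderCache cache;
+//   const TensorSliceReader* reader = cache.GetReader(
+//       "/tmp/ckpt_*", OpenTableTensorSliceReader,
+//       TensorSliceReader::kLoadAllShards);
+//   const TensorSliceReader* again = cache.GetReader(
+//       "/tmp/ckpt_*", OpenTableTensorSliceReader,
+//       TensorSliceReader::kLoadAllShards);
+//   CHECK_EQ(reader, again);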
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc
new file mode 100644
index 0000000000..e14b920003
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_reader_test.cc
@@ -0,0 +1,395 @@
+#include "tensorflow/core/util/tensor_slice_reader.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_writer.h"
+#include "tensorflow/core/util/tensor_slice_reader_cache.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// A simple test where we write a few tensor slices with a number of tensor
+// slice writers and then read them back from a tensor slice reader.
+//
+// We have a 2-d tensor of shape 4 X 5 that looks like this:
+//
+// 0 1 2 3 4
+// 5 6 7 8 9
+// 10 11 12 13 14
+// 15 16 17 18 19
+//
+// We assume this is a row-major matrix.
+
+void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function,
+ TensorSliceReader::OpenTableFunction open_function) {
+ const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
+
+ TensorShape shape({4, 5});
+
+ // File #0 contains a slice that is the top two rows:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ const string fname = strings::StrCat(fname_base, "_0");
+ TensorSliceWriter writer(fname, create_function);
+ const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
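+    // "0,2:-" parses as: dimension 0 starts at 0 with length 2, dimension 1
+    // is the full dimension -- i.e. the top two rows of the 4 x 5 tensor.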
+ TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // File #1 contains two slices:
+ //
+ // slice #0 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ //
+ // slice #1 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ {
+ const string fname = strings::StrCat(fname_base, "_1");
+ TensorSliceWriter writer(fname, create_function);
+ // slice #0
+ {
+ const float data[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ // slice #1
+ {
+ const float data[] = {18, 19};
+ TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we need to read the tensor slices
+ const string filepattern = strings::StrCat(fname_base, "_*");
+ TensorSliceReader reader(filepattern, open_function);
+ EXPECT_OK(reader.status());
+ EXPECT_EQ(2, reader.num_files());
+
+ // We query some of the tensors
+ {
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
+ EXPECT_EQ(
+ "dim { size: 4 } "
+ "dim { size: 5 }",
+ shape.DebugString());
+ EXPECT_EQ(DT_FLOAT, type);
+ EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
+ }
+
+ // Now we query some slices
+ //
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ float results[10];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ float expected[] = {5, 6, 7, 8, 9};
+ float results[5];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ float results[6];
+ EXPECT_FALSE(reader.CopySliceData("test", s, results));
+ }
+}
+
+TEST(TensorSliceReaderTest, SimpleFloat) {
+ SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader);
+}
+
+template <typename T, typename U>
+void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function,
+ TensorSliceReader::OpenTableFunction open_function,
+ const string& checkpoint_file) {
+ const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file);
+
+ TensorShape shape({4, 5});
+
+ // File #0 contains a slice that is the top two rows:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ const string fname = strings::StrCat(fname_base, "_0");
+ TensorSliceWriter writer(fname, create_function);
+ const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // File #1 contains two slices:
+ //
+ // slice #0 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ //
+ // slice #1 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ {
+ const string fname = strings::StrCat(fname_base, "_1");
+ TensorSliceWriter writer(fname, create_function);
+ // slice #0
+ {
+ const T data[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ // slice #1
+ {
+ const T data[] = {18, 19};
+ TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we need to read the tensor slices
+ const string filepattern = strings::StrCat(fname_base, "_*");
+ TensorSliceReader reader(filepattern, open_function);
+ EXPECT_OK(reader.status());
+ EXPECT_EQ(2, reader.num_files());
+
+ // We query some of the tensors
+ {
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader.HasTensor("test", &shape, &type));
+ EXPECT_EQ(
+ "dim { size: 4 } "
+ "dim { size: 5 }",
+ shape.DebugString());
+ EXPECT_EQ(DataTypeToEnum<T>::v(), type);
+ EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr));
+ }
+
+ // Now we query some slices
+ //
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ U results[10];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ T expected[] = {5, 6, 7, 8, 9};
+ U results[5];
+ EXPECT_TRUE(reader.CopySliceData("test", s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ U results[6];
+ EXPECT_FALSE(reader.CopySliceData("test", s, results));
+ }
+}
+
+#define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \
+ TEST(TensorSliceReaderTest, Simple##TYPE) { \
+ SimpleIntXHelper<TYPE, SAVED_TYPE>(CreateTableTensorSliceBuilder, \
+ OpenTableTensorSliceReader, \
+ #TYPE "_checkpoint"); \
+ }
+
+TEST_SIMPLE_INT(int32, int32)
+TEST_SIMPLE_INT(int64, int64)
+TEST_SIMPLE_INT(int16, int32)
+TEST_SIMPLE_INT(int8, int32)
+TEST_SIMPLE_INT(uint8, int32)
+
+void CachedTensorSliceReaderTesterHelper(
+ TensorSliceWriter::CreateBuilderFunction create_function,
+ TensorSliceReader::OpenTableFunction open_function) {
+ const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint");
+
+ TensorShape shape({4, 5});
+
+ // File #0 contains a slice that is the top two rows:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ const string fname = strings::StrCat(fname_base, "_0");
+ TensorSliceWriter writer(fname, create_function);
+ const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // File #1 contains two slices:
+ //
+ // slice #0 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ //
+ // slice #1 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ {
+ const string fname = strings::StrCat(fname_base, "_1");
+ TensorSliceWriter writer(fname, create_function);
+ // slice #0
+ {
+ const float data[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ // slice #1
+ {
+ const float data[] = {18, 19};
+ TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+ TF_CHECK_OK(writer.Finish());
+ }
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we need to read the tensor slices
+ TensorSliceReaderCache cache;
+ const string filepattern = strings::StrCat(fname_base, "_*");
+ const TensorSliceReader* reader = cache.GetReader(
+ filepattern, open_function, TensorSliceReader::kLoadAllShards);
+ EXPECT_TRUE(reader != nullptr);
+ EXPECT_EQ(2, reader->num_files());
+
+ // We query some of the tensors
+ {
+ TensorShape shape;
+ DataType type;
+ EXPECT_TRUE(reader->HasTensor("test", &shape, &type));
+ EXPECT_EQ(
+ "dim { size: 4 } "
+ "dim { size: 5 }",
+ shape.DebugString());
+ EXPECT_EQ(DT_FLOAT, type);
+ EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr));
+ }
+
+ // Make sure the reader is cached.
+ const TensorSliceReader* reader2 = cache.GetReader(
+ filepattern, open_function, TensorSliceReader::kLoadAllShards);
+ EXPECT_EQ(reader, reader2);
+
+ reader = cache.GetReader("file_does_not_exist", open_function,
+ TensorSliceReader::kLoadAllShards);
+ EXPECT_TRUE(reader == nullptr);
+}
+
+TEST(CachedTensorSliceReaderTest, SimpleFloat) {
+ CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder,
+ OpenTableTensorSliceReader);
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc
new file mode 100644
index 0000000000..765686f189
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set.cc
@@ -0,0 +1,148 @@
+#include "tensorflow/core/util/tensor_slice_set.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/tensor_slice_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
+ : shape_(shape), type_(type) {}
+
+TensorSliceSet::~TensorSliceSet() {}
+
+Status TensorSliceSet::Register(const TensorSlice& slice,
+ const string& tag, const float* data) {
+ TensorShape result_shape;
+ TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
+ string str = slice.DebugString();
+ // We check if there is any intersection between this slice and any of the
+ // registered slices.
+ for (const auto x : slices_) {
+ if (slice.Overlaps(x.second.slice)) {
+ return errors::Internal("Overlapping slices: existing slice = ", x.first,
+ ", new slice = ", str);
+ }
+ }
+ // No overlap: we can now insert the slice
+ TensorSliceSet::SliceInfo info = {slice, tag, data,
+ result_shape.num_elements()};
+ slices_.insert(std::make_pair(str, info));
+ return Status::OK();
+}
+
+// TODO(yangke): merge Query() with QueryMeta()
+bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
+ Status s;
+ string str = slice.DebugString();
+  // First we check if there is an exact match (this is the dominant case).
+ const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
+ if (info) {
+ if (data) {
+ std::copy_n(info->data, info->num_floats, data);
+ }
+ return true;
+ } else {
+    // We didn't find any exact match but there is still a possibility that
+    // multiple existing slices can be patched together to output the slice.
+ // We figure this out by computing the intersection of each of the existing
+ // slices with the query slice, and check if the union of all these
+ // intersections cover the entire slice. We rely on the fact that the
+ // existing slices don't have any intersection among themselves.
+ TensorShape target_shape;
+ Status s;
+ s = slice.SliceTensorShape(shape_, &target_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ int64 total_size = target_shape.num_elements();
+
+ int64 overlap_size = 0;
+ TensorSlice intersection;
+ TensorShape inter_shape;
+ for (const auto x : slices_) {
+ if (slice.Intersect(x.second.slice, &intersection)) {
+ s = intersection.SliceTensorShape(shape_, &inter_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ overlap_size += inter_shape.num_elements();
+ }
+ }
+ if (total_size == overlap_size) {
+ // We have it!
+ // Now we need to copy the data to "data"
+ if (data) {
+ for (const auto x : slices_) {
+ CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
+ x.second.data, data);
+ }
+ }
+ return true;
+ } else {
+      // We don't have all the data for the requested tensor slice
+ return false;
+ }
+ }
+}
+
+bool TensorSliceSet::QueryMeta(
+ const TensorSlice& slice,
+ std::vector<std::pair<TensorSlice, string>>* results) const {
+ results->clear();
+ Status s;
+ string str = slice.DebugString();
+  // First we check if there is an exact match (this is the dominant case).
+ const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
+ if (info) {
+ results->emplace_back(std::make_pair(info->slice, info->tag));
+ return true;
+ } else {
+    // We didn't find any exact match but there is still a possibility that
+ // multiple existing slices can be patched together to output the slice.
+ // We figure this out by computing the intersection of each of the existing
+ // slices with the query slice, and check if the union of all these
+ // intersections cover the entire slice. We rely on the fact that the
+ // existing slices don't have any intersection among themselves.
+ TensorShape target_shape;
+ Status s;
+ s = slice.SliceTensorShape(shape_, &target_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ int64 total_size = target_shape.num_elements();
+
+ int64 overlap_size = 0;
+ TensorSlice intersection;
+ TensorShape inter_shape;
+ for (const auto x : slices_) {
+ if (slice.Intersect(x.second.slice, &intersection)) {
+ s = intersection.SliceTensorShape(shape_, &inter_shape);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ overlap_size += inter_shape.num_elements();
+ results->emplace_back(std::make_pair(x.second.slice, x.second.tag));
+ }
+ }
+ if (total_size == overlap_size) {
+ // We have it!
+ return true;
+ } else {
+      // We don't have all the data for the requested tensor slice
+ results->clear();
+ return false;
+ }
+ }
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h
new file mode 100644
index 0000000000..f3f7ac0e76
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set.h
@@ -0,0 +1,73 @@
+// A class to manage slices of a tensor. You can "register" a set of slices for a
+// tensor and then "query" if we have data for a given slice.
+
+// TODO(yangke): consider moving it to a more private place so that we don't
+// need to expose the API.
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+
+#include <string> // for string
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h" // for int64
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/stringpiece.h" // for StringPiece
+#include "tensorflow/core/public/status.h" // for Status
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceSet {
+ public:
+ TensorSliceSet(const TensorShape& shape, DataType type);
+ virtual ~TensorSliceSet();
+
+ const TensorShape& shape() const { return shape_; }
+ const DataType type() const { return type_; }
+
+ // Register a new slice for the tensor. The "tag" is an arbitrary string
+ // associated with the slice (in one application it denotes the name of the
+ // file that contains the slice); the "data" points to the data of the tensor
+ // slice (it can be a nullptr).
+  // We don't take ownership of "data", and if it is not nullptr the caller
+  // needs to make sure the data stays available for the lifetime of the
+  // tensor slice set.
+ Status Register(const TensorSlice& slice, const string& tag,
+ const float* data);
+
+ // Query about a new slice: checks if we have data for "slice" and if we have
+ // the data and "data" is not nullptr, fill "data" with the slice data. The
+  // caller needs to make sure "data" points to a large enough buffer.
+ // TODO(yangke): avoid unnecessary copying by using a core::RefCounted
+ // pointer.
+ bool Query(const TensorSlice& slice, float* data) const;
+
+ // Alternative way of querying about a new slice: instead of copying the
+ // data, it returns a list of meta data about the stored slices that will
+ // supply data for the slice.
+ bool QueryMeta(
+ const TensorSlice& slice,
+ std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const;
+
+ private:
+ const TensorShape shape_;
+ const DataType type_;
+ struct SliceInfo {
+ TensorSlice slice;
+ const string tag;
+ const float* data;
+ int64 num_floats;
+ };
+ // We maintain a mapping from the slice string to the slice information.
+ std::unordered_map<string, SliceInfo> slices_;
+};
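+
+// A minimal usage sketch (the shape, slice specs, tags, and data below are
+// made up for illustration): register a slice of a 4 x 5 float tensor and
+// then query a sub-slice covered by it.
+//
+//   TensorSliceSet tss(TensorShape({4, 5}), DT_FLOAT);
+//   const float top_two_rows[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+//   TF_CHECK_OK(tss.Register(TensorSlice::ParseOrDie("0,2:-"), "file_0",
+//                            top_two_rows));
+//   float row_1[5];
+//   // Succeeds because the queried slice is covered by the registered one.
+//   CHECK(tss.Query(TensorSlice::ParseOrDie("1,1:-"), row_1));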
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
diff --git a/tensorflow/core/util/tensor_slice_set_test.cc b/tensorflow/core/util/tensor_slice_set_test.cc
new file mode 100644
index 0000000000..fb2f46f34c
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set_test.cc
@@ -0,0 +1,227 @@
+#include "tensorflow/core/util/tensor_slice_set.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include <gtest/gtest.h>
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+// A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
+//
+// 0 1 2 3 4
+// 5 6 7 8 9
+// 10 11 12 13 14
+// 15 16 17 18 19
+//
+// We assume this is a row-major matrix.
+//
+// We store the tensor in a couple of slices and verify that we can recover all
+// of them.
+TEST(TensorSliceSetTest, QueryTwoD) {
+ TensorShape shape({4, 5});
+
+ TensorSliceSet tss(shape, DT_FLOAT);
+ // We store a few slices.
+
+ // Slice #1 is the top two rows:
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(tss.Register(slice_1, "", src_1));
+
+ // Slice #2 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ const float src_2[] = {10, 11, 12, 15, 16, 17};
+ TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(tss.Register(slice_2, "", src_2));
+
+ // Slice #3 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ const float src_3[] = {18, 19};
+ TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(tss.Register(slice_3, "", src_3));
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we query some of the slices
+
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ float results[10];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ float expected[] = {5, 6, 7, 8, 9};
+ float results[5];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #3 is a more complicated match: it needs the combination of a couple
+ // of slices
+ // . . . . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
+ float expected[] = {5, 6, 7, 10, 11, 12};
+ float results[6];
+ EXPECT_TRUE(tss.Query(s, results));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(expected[i], results[i]);
+ }
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ float results[6];
+ EXPECT_FALSE(tss.Query(s, results));
+ }
+}
+
+// Testing the meta version of the tensor slice set.
+TEST(TensorSliceSetTest, QueryMetaTwoD) {
+ TensorShape shape({4, 5});
+
+ TensorSliceSet tss(shape, DT_INT32);
+ // We store a few slices.
+
+ // Slice #1 is the top two rows:
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
+ TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
+
+ // Slice #2 is the bottom left corner
+ // . . . . .
+ // . . . . .
+ // 10 11 12 . .
+ // 15 16 17 . .
+ TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
+ TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
+
+ // Slice #3 is the bottom right corner
+ // . . . . .
+ // . . . . .
+ // . . . . .
+ // . . . 18 19
+ TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
+ TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
+
+ // Notice that we leave a hole in the tensor
+ // . . . . .
+ // . . . . .
+ // . . . (13) (14)
+ // . . . . .
+
+ // Now we query some of the slices
+
+ // Slice #1 is an exact match
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ // We just need slice_1 for this
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(1, results.size());
+ EXPECT_EQ("0,2:-", results[0].first.DebugString());
+ EXPECT_EQ("slice_1", results[0].second);
+ }
+
+ // Slice #2 is a subset match
+ // . . . . .
+ // 5 6 7 8 9
+ // . . . . .
+ // . . . . .
+ // We just need slice_1 for this
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(1, results.size());
+ EXPECT_EQ("0,2:-", results[0].first.DebugString());
+ EXPECT_EQ("slice_1", results[0].second);
+ }
+
+ // Slice #3 is a more complicated match: it needs the combination of a couple
+ // of slices
+ // . . . . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ // We need both slice_1 and slice_2 for this.
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_TRUE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(2, results.size());
+ EXPECT_EQ("2,2:0,3", results[0].first.DebugString());
+ EXPECT_EQ("slice_2", results[0].second);
+ EXPECT_EQ("0,2:-", results[1].first.DebugString());
+ EXPECT_EQ("slice_1", results[1].second);
+ }
+
+ // Slice #4 includes the hole and so there is no match
+ // . . . . .
+ // . . 7 8 9
+ // . . 12 13 14
+ // . . . . .
+ {
+ TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
+ std::vector<std::pair<TensorSlice, string>> results;
+ EXPECT_FALSE(tss.QueryMeta(s, &results));
+ EXPECT_EQ(0, results.size());
+ }
+}
+
+} // namespace
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h
new file mode 100644
index 0000000000..5422c3bef3
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_util.h
@@ -0,0 +1,88 @@
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
+
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+
+// Some hackery that invokes Eigen tensor ops to copy over tensor slices for
+// variable dimension tensors.
+// TODO(yangke): get rid of that once the variable dimension tensor support is
+// in.
+static const int kTensorSliceMaxRank = 8;
+
+// Create a tensor map with the given shape: we support up to 8 dimensions. If
+// the shape has fewer than 8 dimensions, we pad the remaining dimensions with
+// 1.
+template <typename T>
+Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>>
+GetEigenTensorMapFromTensorShape(const TensorShape& shape, T* data) {
+ Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> dsizes =
+ shape.AsEigenDSizesWithPadding<kTensorSliceMaxRank>();
+ Eigen::TensorMap<Eigen::Tensor<T, kTensorSliceMaxRank, Eigen::RowMajor>> eig(
+ data, dsizes);
+ return eig;
+}
+
+// Given a tensor described by "shape", two slices "slice_s" and "slice_d",
+// and two pointers "ptr_s" and "ptr_d", where "ptr_s" points to a chunk of
+// memory that stores the data for "slice_s" and "ptr_d" points to a chunk of
+// memory that stores the data for "slice_d", this function copies the data
+// that belongs to the intersection of the two slices from slice_s to slice_d.
+// It uses Eigen's cast<DstT>() to convert from SrcT to DstT. Returns true
+// iff the two slices share any intersection (and thus some data is copied).
+// TODO(yangke): figure out if we can make it private.
+template <typename SrcT, typename DstT>
+static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape,
+ const TensorSlice& slice_s,
+ const TensorSlice& slice_d,
+ const SrcT* ptr_s,
+ DstT* ptr_d) {
+ CHECK_LE(shape.dims(), kTensorSliceMaxRank) << "Only tensors of size up to "
+ << kTensorSliceMaxRank
+ << " are supported";
+ // We need to compute the intersection of the two slices.
+ TensorSlice inter;
+ if (!slice_s.Intersect(slice_d, &inter)) {
+ // There is no intersection: returns false.
+ return false;
+ } else {
+ // We need to compute the applied shapes after applying slice_s and
+ // slice_d.
+ TensorShape shp_s, shp_d;
+ Status s;
+ s = slice_s.SliceTensorShape(shape, &shp_s);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+ s = slice_d.SliceTensorShape(shape, &shp_d);
+ if (!s.ok()) {
+ LOG(WARNING) << s;
+ return false;
+ }
+
+ // We need to compute the relative slice of "inter" w.r.t. both slice_s and
+ // slice_d.
+ TensorSlice rel_s, rel_d;
+ slice_s.ComputeRelative(inter, &rel_s);
+ slice_d.ComputeRelative(inter, &rel_d);
+
+ // Get the eigen tensor maps to the data.
+ auto t_s = GetEigenTensorMapFromTensorShape(shp_s, ptr_s);
+ auto t_d = GetEigenTensorMapFromTensorShape(shp_d, ptr_d);
+
+ Eigen::DSizes<Eigen::DenseIndex, kTensorSliceMaxRank> s_start, s_len,
+ d_start, d_len;
+
+ rel_s.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_s, &s_start, &s_len);
+ rel_d.FillIndicesAndSizes<kTensorSliceMaxRank>(shp_d, &d_start, &d_len);
+ t_d.slice(d_start, d_len) = t_s.slice(s_start, s_len).template cast<DstT>();
+ return true;
+ }
+}
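+
+// A minimal usage sketch (the shape, slices, and data below are made up for
+// illustration): both slices describe the same 2 x 3 interior region of a
+// 4 x 5 tensor, so this is a straight copy of 6 floats.
+//
+//   TensorShape shape({4, 5});
+//   const float src[6] = {6, 7, 8, 11, 12, 13};  // data for slice "1,2:1,3"
+//   float dst[6] = {0};                          // buffer for slice "1,2:1,3"
+//   CHECK(CopyDataFromTensorSliceToTensorSlice(
+//       shape, TensorSlice::ParseOrDie("1,2:1,3"),
+//       TensorSlice::ParseOrDie("1,2:1,3"), src, dst));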
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_UTIL_H_
diff --git a/tensorflow/core/util/tensor_slice_util_test.cc b/tensorflow/core/util/tensor_slice_util_test.cc
new file mode 100644
index 0000000000..348b0c884e
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_util_test.cc
@@ -0,0 +1,91 @@
+#include "tensorflow/core/util/tensor_slice_util.h"
+
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+// Testing copying data from one tensor slice to another tensor slice
+TEST(TensorSliceUtilTest, CopyTensorSliceToTensorSlice) {
+  // We map out a 2-d tensor of size 4 X 5 and we want the final result to
+  // look like this:
+ //
+ // 0 1 2 3 4
+ // 5 6 7 8 9
+ // 10 11 12 13 14
+ // 15 16 17 18 19
+ //
+ // We assume this is a row-major matrix
+ //
+ TensorShape shape({4, 5});
+
+ // We will try to do a couple of slice to slice copies.
+
+ // Case 1: simple identity copy
+ // The slice is the "interior" of the matrix
+ // . . . . .
+ // . 6 7 8 .
+  // .  11 12 13 .
+ // . . . . .
+ {
+ TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3");
+ TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,3");
+ const float ptr_s[] = {6, 7, 8, 11, 12, 13};
+ float ptr_d[6];
+ for (int i = 0; i < 6; ++i) {
+ ptr_d[i] = 0;
+ }
+ EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d,
+ ptr_s, ptr_d));
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(ptr_s[i], ptr_d[i]);
+ }
+ }
+
+ // Case 2: no intersection
+ {
+ TensorSlice slice_s = TensorSlice::ParseOrDie("1,2:1,3");
+ TensorSlice slice_d = TensorSlice::ParseOrDie("3,1:2,3");
+ const float ptr_s[] = {6, 7, 8, 11, 12, 13};
+ float ptr_d[6];
+ EXPECT_FALSE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d,
+ ptr_s, ptr_d));
+ }
+
+ // Case 3: a trickier case
+ // The source slice is on the upper left corner:
+ // 0 1 2 . .
+ // 5 6 7 . .
+ // 10 11 12 . .
+ // . . . . .
+ //
+ // The destination slice is the right part of the middle stripe:
+ // . . . . .
+ // . X X X X
+ // . X X X X
+ // . . . . .
+ //
+ // So we expect to copy over the 2X2 block:
+ // . . . . .
+ // . 6 7 . .
+ // . 11 12 . .
+ // . . . . .
+ {
+ TensorSlice slice_s = TensorSlice::ParseOrDie("0,3:0,3");
+ TensorSlice slice_d = TensorSlice::ParseOrDie("1,2:1,4");
+ const float ptr_s[] = {0, 1, 2, 5, 6, 7, 10, 11, 12};
+ float ptr_d[8];
+ for (int i = 0; i < 8; ++i) {
+ ptr_d[i] = 0;
+ }
+ EXPECT_TRUE(CopyDataFromTensorSliceToTensorSlice(shape, slice_s, slice_d,
+ ptr_s, ptr_d));
+ const float expected[] = {6, 7, 0, 0, 11, 12, 0, 0};
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(expected[i], ptr_d[i]);
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
new file mode 100644
index 0000000000..bb2fd96c05
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -0,0 +1,110 @@
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/io/table_builder.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/env.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+namespace {
+
+class TableBuilder : public TensorSliceWriter::Builder {
+ public:
+ TableBuilder(const string& name, WritableFile* f)
+ : name_(name),
+ file_(f),
+ builder_(new table::TableBuilder(table::Options(), f)) {}
+ void Add(StringPiece key, StringPiece val) override {
+ builder_->Add(key, val);
+ }
+ Status Finish(int64* file_size) override {
+ *file_size = -1;
+ Status s = builder_->Finish();
+ if (s.ok()) {
+ s = file_->Close();
+ if (s.ok()) {
+ *file_size = builder_->FileSize();
+ }
+ }
+ if (!s.ok()) {
+ s = errors::Internal("Error writing (tmp) checkpoint file: ", name_, ": ",
+ s.ToString());
+ }
+ builder_.reset();
+ file_.reset();
+ return s;
+ }
+
+ private:
+ string name_;
+ std::unique_ptr<WritableFile> file_;
+ std::unique_ptr<table::TableBuilder> builder_;
+};
+} // anonymous namespace
+
+Status CreateTableTensorSliceBuilder(
+ const string& name, TensorSliceWriter::Builder** builder) {
+ *builder = nullptr;
+ WritableFile* f;
+ Status s = Env::Default()->NewWritableFile(name, &f);
+ if (s.ok()) {
+ *builder = new TableBuilder(name, f);
+ return Status::OK();
+ } else {
+ return s;
+ }
+}
+
+TensorSliceWriter::TensorSliceWriter(const string& filename,
+ CreateBuilderFunction create_builder)
+ : filename_(filename),
+ create_builder_(create_builder),
+ tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
+ slices_(0) {}
+
+Status TensorSliceWriter::Finish() {
+ Builder* b;
+ Status s = create_builder_(tmpname_, &b);
+ if (!s.ok()) {
+ delete b;
+ return s;
+ }
+ std::unique_ptr<Builder> builder(b);
+
+ // We save the saved tensor slice metadata as the first element.
+ string meta;
+ sts_.AppendToString(&meta);
+ builder->Add(kSavedTensorSlicesKey, meta);
+
+ // Go through all the data and add them
+ for (const auto& x : data_) {
+ builder->Add(x.first, x.second);
+ }
+
+ int64 file_size;
+ s = builder->Finish(&file_size);
+ // We need to rename the file to the proper name
+ if (s.ok()) {
+ s = Env::Default()->RenameFile(tmpname_, filename_);
+ if (s.ok()) {
+ VLOG(1) << "Written " << slices_ << " slices for "
+ << sts_.meta().tensor_size() << " tensors (" << file_size
+ << " bytes) to " << filename_;
+ } else {
+ LOG(ERROR) << "Failed to rename file " << tmpname_ << " to " << filename_;
+ }
+ } else {
+ Env::Default()->DeleteFile(tmpname_);
+ }
+ return s;
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
new file mode 100644
index 0000000000..cce3880cb3
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -0,0 +1,149 @@
+// The utility to write checkpoints for Google Brain tensor ops and v3
+// checkpoints for dist_belief.
+//
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/util/saved_tensor_slice.pb.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceWriter {
+ public:
+ // Abstract interface that TensorSliceWriter uses for building
+ class Builder {
+ public:
+ virtual ~Builder() {}
+ virtual void Add(StringPiece key, StringPiece value) = 0;
+ virtual Status Finish(int64* file_size) = 0;
+ };
+ typedef std::function<Status(const string&, Builder**)>
+ CreateBuilderFunction;
+
+ TensorSliceWriter(const string& filename,
+ CreateBuilderFunction create_builder);
+ virtual ~TensorSliceWriter() {}
+ // Adds a slice. We support float and int32 for now.
+  // TODO(yangke): add support for more types
+ template <typename T>
+ Status Add(const string& name, const TensorShape& shape,
+ const TensorSlice& slice, const T* data);
+ Status Finish();
+
+ private:
+ // Allocate "num_elements" elements in "ss" and save the data in "data"
+ // there.
+ template <typename T>
+ static void SaveData(const T* data, int num_elements, SavedSlice* ss);
+
+ const string filename_;
+ const CreateBuilderFunction create_builder_;
+ const string tmpname_;
+
+  // A mapping from the tensor names to their index in sts_.meta().tensor()
+ std::unordered_map<string, int> name_to_index_;
+ // The metadata that holds all the saved tensor slices.
+ SavedTensorSlices sts_;
+ // The data to be written to the builder
+ std::map<string, string> data_;
+ // Total number of slices written
+ int slices_;
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceWriter);
+};
+
+template <typename T>
+Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
+ const TensorSlice& slice, const T* data) {
+ // The tensor and the slice have to be compatible
+ if (shape.dims() != slice.dims()) {
+ return errors::Internal("Incompatible tensor shape and slice: ", "shape = ",
+ shape.DebugString(), ", slice = ",
+ slice.DebugString());
+ }
+ DataType dt = DataTypeToEnum<T>::value;
+ // We need to add an entry for "name" if there isn't an entry already.
+ int index = gtl::FindWithDefault(name_to_index_, name, -1);
+ if (index >= 0) {
+ // The same tensor has been registered -- we verify that the shapes and the
+ // type agree.
+ const SavedSliceMeta& ssm = sts_.meta().tensor(index);
+ CHECK_EQ(name, ssm.name()) << ssm.ShortDebugString();
+ TensorShape ssm_shape(ssm.shape());
+ if (!shape.IsSameSize(ssm_shape)) {
+ return errors::Internal("Mismatching shapes: existing tensor = ",
+ ssm_shape.DebugString(), ", trying to add name ",
+ name, ", shape = ", shape.DebugString());
+ }
+ if (dt != ssm.type()) {
+ return errors::Internal(
+ "Mismatching types: existing type = ", DataTypeString(ssm.type()),
+ ", trying to add name ", name, ", type = ", DataTypeString(dt));
+ }
+ } else {
+ // Insert the new tensor name with the shape information
+ index = sts_.meta().tensor_size();
+ name_to_index_.insert(std::make_pair(name, index));
+ SavedSliceMeta* ssm = sts_.mutable_meta()->add_tensor();
+ ssm->set_name(name);
+ shape.AsProto(ssm->mutable_shape());
+ ssm->set_type(dt);
+ }
+ // Now we need to add the slice info to the list of slices for this tensor.
+ SavedSliceMeta* ssm = sts_.mutable_meta()->mutable_tensor(index);
+ slice.AsProto(ssm->add_slice());
+
+ // Now we need to add the real data.
+ {
+ SavedTensorSlices sts;
+ SavedSlice* ss = sts.mutable_data();
+ ss->set_name(name);
+ slice.AsProto(ss->mutable_slice());
+ TensorShape saved_shape(ssm->shape());
+ TensorShape sliced_shape;
+ TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape));
+ SaveData(data, sliced_shape.num_elements(), ss);
+ string key = EncodeTensorNameSlice(name, slice);
+ // TODO(yangke): consider doing a two-pass scheme where the first pass just
+ // lists the tensor slices we want to save and a second pass actually
+ // sets the data. Need to figure out whether the interface works well.
+ std::pair<string, string> key_value(key, "");
+ sts.AppendToString(&key_value.second);
+ data_.insert(key_value);
+ }
+ ++slices_;
+ return Status::OK();
+}
+
+template <typename T>
+void TensorSliceWriter::SaveData(const T* data, int num_elements,
+ SavedSlice* ss) {
+ Fill(data, num_elements, ss->mutable_data());
+}
+
+// Create a table builder that will write to "filename" in
+// tensorflow::io::Table format. If successful, return OK
+// and set "*builder" to the allocated builder. Otherwise, return a
+// non-OK status.
+Status CreateTableTensorSliceBuilder(const string& filename,
+ TensorSliceWriter::Builder** builder);
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
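A brief usage sketch of the interface declared above; the tensor name, shape, and values are made up for illustration, and the slice specs use the same ParseOrDie syntax as the test below.

    #include "tensorflow/core/util/tensor_slice_writer.h"

    namespace tensorflow {
    namespace checkpoint {

    // Illustrative only: writes one 4x10 float tensor as two column slices.
    Status WriteTwoSlices(const string& filename) {
      TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
      TensorShape shape({4, 10});
      // "-:0,1" = all rows, column 0; "-:1,1" = all rows, column 1.
      const float col0[4] = {0.0f, 1.0f, 2.0f, 3.0f};
      const float col1[4] = {4.0f, 5.0f, 6.0f, 7.0f};
      TF_RETURN_IF_ERROR(
          writer.Add("weights", shape, TensorSlice::ParseOrDie("-:0,1"), col0));
      TF_RETURN_IF_ERROR(
          writer.Add("weights", shape, TensorSlice::ParseOrDie("-:1,1"), col1));
      // Builds the table file and renames it into place.
      return writer.Finish();
    }

    }  // namespace checkpoint
    }  // namespace tensorflow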
diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc
new file mode 100644
index 0000000000..ca3dffe422
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_writer_test.cc
@@ -0,0 +1,248 @@
+#include "tensorflow/core/util/tensor_slice_writer.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
+#include "tensorflow/core/util/tensor_slice_reader.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceWriteTestHelper {
+ public:
+ static void CheckEntries(const string& fname);
+ static void GetData(TensorSliceReader::Table* table, const string& name,
+ const TensorSlice& slice, SavedSlice* ss);
+};
+
+namespace {
+
+// Testing that an array is what is expected
+void ExpectIdenticalFloatArrays(const float* expected, int size,
+ const float* actual) {
+ // TODO(yangke): copy some of the Dump* functions over
+ // LOG(INFO) << "Expected = " << DumpFloatArray(expected, size);
+ // LOG(INFO) << "Actual = " << DumpFloatArray(actual, size);
+ for (int i = 0; i < size; ++i) {
+ EXPECT_NEAR(expected[i], actual[i], 1e-6);
+ }
+}
+
+template <typename T, typename U>
+void ExpectIdenticalIntArrays(const T* expected, int size, const U* actual) {
+ for (int i = 0; i < size; ++i) {
+ EXPECT_EQ(expected[i], static_cast<T>(actual[i]));
+ }
+}
+
+// Nifty routine to get the size of an array
+template <typename T, unsigned SIZE>
+inline size_t ArraySize(const T(&v)[SIZE]) {
+ return SIZE;
+}
+
+// A simple test on writing a few tensor slices
+// TODO(yangke): refactor into smaller tests: will do as we add more stuff to
+// the writer.
+TEST(TensorSliceWriteTest, SimpleWrite) {
+ const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
+
+ TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);
+
+ // Add some int32 tensor slices
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:0,1");
+ const int32 data[] = {0, 1, 2, 3, 4};
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+
+ // Two slices share the same tensor name
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
+ const int32 data[] = {10, 11, 12, 13, 14};
+ TF_CHECK_OK(writer.Add("test", shape, slice, data));
+ }
+
+ // Another slice from a different float tensor -- its name sorts before the
+ // previous tensor's, so it should appear first in the table.
+ {
+ TensorShape shape({3, 2});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:-");
+ const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
+ TF_CHECK_OK(writer.Add("AA", shape, slice, data));
+ }
+
+ // A slice with int64 data
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
+ const int64 data[] = {10, 11, 12, 13, 14};
+ TF_CHECK_OK(writer.Add("int64", shape, slice, data));
+ }
+
+ // A slice with int16 data
+ {
+ TensorShape shape({5, 10});
+ TensorSlice slice = TensorSlice::ParseOrDie("-:3,1");
+ const int16 data[] = {10, 11, 12, 13, 14};
+ TF_CHECK_OK(writer.Add("int16", shape, slice, data));
+ }
+
+ TF_CHECK_OK(writer.Finish());
+
+ // Now we examine the checkpoint file manually.
+ TensorSliceWriteTestHelper::CheckEntries(filename);
+}
+
+} // namespace
+
+void TensorSliceWriteTestHelper::GetData(TensorSliceReader::Table* table,
+ const string& name,
+ const TensorSlice& slice,
+ SavedSlice* ss) {
+ string key = EncodeTensorNameSlice(name, slice);
+ string value;
+ EXPECT_TRUE(table->Get(key, &value));
+ SavedTensorSlices sts;
+ EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
+ EXPECT_FALSE(sts.has_meta());
+ *ss = sts.data();
+ EXPECT_EQ(name, ss->name());
+ TensorSlice slice2(ss->slice());
+ EXPECT_EQ(slice.DebugString(), slice2.DebugString());
+}
+
+void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
+ TensorSliceReader::Table* tptr;
+ TF_CHECK_OK(OpenTableTensorSliceReader(fname, &tptr));
+ std::unique_ptr<TensorSliceReader::Table> table(tptr);
+ CHECK_NOTNULL(table.get());
+
+ // We expect a block of SavedTensorSlices
+ string value;
+ ASSERT_TRUE(table->Get(kSavedTensorSlicesKey, &value));
+ {
+ SavedTensorSlices sts;
+ EXPECT_TRUE(ParseProtoUnlimited(&sts, value));
+ // We also expect meta entries for all four tensors
+ EXPECT_TRUE(sts.has_meta());
+ EXPECT_EQ(4, sts.meta().tensor_size());
+ // We don't expect any data in the first block.
+ EXPECT_FALSE(sts.has_data());
+ // The tensors should be stored in the same order as they were first
+ // added.
+ {
+ // The two slices of the "test" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(0);
+ EXPECT_EQ("test", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 5 } "
+ "dim { size: 10 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_INT32, ssm.type());
+ EXPECT_EQ(2, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ TensorSlice s1(ssm.slice(1));
+ EXPECT_EQ("-:0,1", s0.DebugString());
+ EXPECT_EQ("-:3,1", s1.DebugString());
+ }
+ {
+ // The "AA" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(1);
+ EXPECT_EQ("AA", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 3 } "
+ "dim { size: 2 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_FLOAT, ssm.type());
+ EXPECT_EQ(1, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ EXPECT_EQ("-:-", s0.DebugString());
+ }
+ {
+ // The "int64" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(2);
+ EXPECT_EQ("int64", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 5 } "
+ "dim { size: 10 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_INT64, ssm.type());
+ EXPECT_EQ(1, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ EXPECT_EQ("-:3,1", s0.DebugString());
+ }
+ {
+ // The "int16" tensor
+ const SavedSliceMeta& ssm = sts.meta().tensor(3);
+ EXPECT_EQ("int16", ssm.name());
+ EXPECT_EQ(
+ "dim { size: 5 } "
+ "dim { size: 10 }",
+ ssm.shape().ShortDebugString());
+ EXPECT_EQ(DT_INT16, ssm.type());
+ EXPECT_EQ(1, ssm.slice_size());
+ TensorSlice s0(ssm.slice(0));
+ EXPECT_EQ("-:3,1", s0.DebugString());
+ }
+ }
+
+ // We expect 5 blocks of tensor data
+ {
+ // Block 1: we expect it to be the full slice of the "AA" tensor
+ SavedSlice ss;
+ GetData(table.get(), "AA", TensorSlice(2), &ss);
+ const float data[] = {1.2, 1.3, 1.4, 2.1, 2.2, 2.3};
+ EXPECT_EQ(ArraySize(data), ss.data().float_val_size());
+ ExpectIdenticalFloatArrays(data, ArraySize(data),
+ ss.data().float_val().data());
+ }
+
+ {
+ // Block 2: we expect it to be the first slice of the "test" tensor
+ SavedSlice ss;
+ GetData(table.get(), "test", TensorSlice({{0, -1}, {0, 1}}), &ss);
+ const int32 data[] = {0, 1, 2, 3, 4};
+ EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
+ }
+
+ {
+ // Block 3: we expect it to be the second slice of the "test" tensor
+ SavedSlice ss;
+ GetData(table.get(), "test", TensorSlice({{0, -1}, {3, 1}}), &ss);
+ const int32 data[] = {10, 11, 12, 13, 14};
+ EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
+ }
+
+ {
+ // Block 4: we expect it to be the slice of the "int64" tensor
+ SavedSlice ss;
+ GetData(table.get(), "int64", TensorSlice({{0, -1}, {3, 1}}), &ss);
+ const int64 data[] = {10, 11, 12, 13, 14};
+ EXPECT_EQ(ArraySize(data), ss.data().int64_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data),
+ ss.data().int64_val().data());
+ }
+
+ {
+ // Block 5: we expect it to be the slice of the "int16" tensor
+ SavedSlice ss;
+ GetData(table.get(), "int16", TensorSlice({{0, -1}, {3, 1}}), &ss);
+ const int16 data[] = {10, 11, 12, 13, 14};
+ EXPECT_EQ(ArraySize(data), ss.data().int_val_size());
+ ExpectIdenticalIntArrays(data, ArraySize(data), ss.data().int_val().data());
+ }
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc
new file mode 100644
index 0000000000..544b48a679
--- /dev/null
+++ b/tensorflow/core/util/use_cudnn.cc
@@ -0,0 +1,20 @@
+#include "tensorflow/core/util/use_cudnn.h"
+
+#include <stdlib.h>
+
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+bool CanUseCudnn() {
+ const char* tf_use_cudnn = getenv("TF_USE_CUDNN");
+ if (tf_use_cudnn != nullptr) {
+ string tf_use_cudnn_str = tf_use_cudnn;
+ if (tf_use_cudnn_str == "0") {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h
new file mode 100644
index 0000000000..20ce24c513
--- /dev/null
+++ b/tensorflow/core/util/use_cudnn.h
@@ -0,0 +1,12 @@
+// Utility to check whether the cuDNN dependency can be used.
+
+#ifndef TENSORFLOW_UTIL_USE_CUDNN_H_
+#define TENSORFLOW_UTIL_USE_CUDNN_H_
+
+namespace tensorflow {
+
+bool CanUseCudnn();
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_USE_CUDNN_H_
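A minimal sketch of how a caller might consult this check; the function and the two convolution paths are hypothetical, and only CanUseCudnn() and the TF_USE_CUDNN=0 opt-out come from the files above.

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/util/use_cudnn.h"

    namespace tensorflow {

    // Hypothetical dispatch: prefer the cuDNN path unless TF_USE_CUDNN=0.
    void LaunchConvolution(bool built_with_cudnn) {
      if (built_with_cudnn && CanUseCudnn()) {
        VLOG(1) << "Using the cuDNN convolution path.";
        // ... call into the cuDNN-backed implementation ...
      } else {
        VLOG(1) << "Falling back to the non-cuDNN convolution path.";
        // ... call into the fallback implementation ...
      }
    }

    }  // namespace tensorflow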
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
new file mode 100644
index 0000000000..14ac513074
--- /dev/null
+++ b/tensorflow/core/util/util.cc
@@ -0,0 +1,81 @@
+#include "tensorflow/core/util/util.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+
+StringPiece NodeNamePrefix(const StringPiece& op_name) {
+ StringPiece sp(op_name);
+ auto p = sp.find('/');
+ if (p == StringPiece::npos || p == 0) {
+ return "";
+ } else {
+ return StringPiece(sp.data(), p);
+ }
+}
+
+StringPiece NodeNameFullPrefix(const StringPiece& op_name) {
+ StringPiece sp(op_name);
+ auto p = sp.rfind('/');
+ if (p == StringPiece::npos || p == 0) {
+ return "";
+ } else {
+ return StringPiece(sp.data(), p);
+ }
+}
+
+MovingAverage::MovingAverage(int window)
+ : window_(window),
+ sum_(0.0),
+ data_(new double[window_]),
+ head_(0),
+ count_(0) {
+ CHECK_GE(window, 1);
+}
+
+MovingAverage::~MovingAverage() { delete[] data_; }
+
+void MovingAverage::Clear() {
+ count_ = 0;
+ head_ = 0;
+ sum_ = 0;
+}
+
+double MovingAverage::GetAverage() const {
+ if (count_ == 0) {
+ return 0;
+ } else {
+ return static_cast<double>(sum_) / count_;
+ }
+}
+
+void MovingAverage::AddValue(double v) {
+ if (count_ < window_) {
+ // This is the warmup phase. We don't have a full window's worth of data.
+ head_ = count_;
+ data_[count_++] = v;
+ } else {
+ if (window_ == ++head_) {
+ head_ = 0;
+ }
+ // Toss the oldest element
+ sum_ -= data_[head_];
+ // Add the newest element
+ data_[head_] = v;
+ }
+ sum_ += v;
+}
+
+static char hex_char[] = "0123456789abcdef";
+
+string PrintMemory(const char* ptr, int n) {
+ string ret;
+ ret.resize(n * 3);
+ for (int i = 0; i < n; ++i) {
+ ret[i * 3] = ' ';
+ ret[i * 3 + 1] = hex_char[static_cast<unsigned char>(ptr[i]) >> 4];
+ ret[i * 3 + 2] = hex_char[static_cast<unsigned char>(ptr[i]) & 0xf];
+ }
+ return ret;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h
new file mode 100644
index 0000000000..52650bd8ea
--- /dev/null
+++ b/tensorflow/core/util/util.h
@@ -0,0 +1,40 @@
+#ifndef TENSORFLOW_UTIL_UTIL_H_
+#define TENSORFLOW_UTIL_UTIL_H_
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+
+// If op_name has '/' in it, then return everything before the first '/'.
+// Otherwise return empty string.
+StringPiece NodeNamePrefix(const StringPiece& op_name);
+
+// If op_name has '/' in it, then return everything before the last '/'.
+// Otherwise return empty string.
+StringPiece NodeNameFullPrefix(const StringPiece& op_name);
+
+class MovingAverage {
+ public:
+ explicit MovingAverage(int window);
+ ~MovingAverage();
+
+ void Clear();
+
+ double GetAverage() const;
+ void AddValue(double v);
+
+ private:
+ const int window_; // Max size of interval
+ double sum_; // Sum over interval
+ double* data_; // Actual data values
+ int head_; // Offset of the newest statistic in data_
+ int count_; // # of valid data elements in window
+};
+
+// Returns a string printing bytes in ptr[0..n). The output looks
+// like "00 01 ef cd cd ef".
+string PrintMemory(const char* ptr, int n);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_UTIL_H_
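A short usage sketch of the two utilities declared above; the window size and byte values are illustrative.

    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/util/util.h"

    namespace tensorflow {

    // Illustrative only: track the average of the last 3 values.
    void MovingAverageExample() {
      MovingAverage avg(3);
      avg.AddValue(10.0);
      avg.AddValue(20.0);
      avg.AddValue(30.0);
      // Window is full: (10 + 20 + 30) / 3 = 20.
      VLOG(1) << "average = " << avg.GetAverage();
      avg.AddValue(60.0);
      // Oldest value (10) is dropped: (20 + 30 + 60) / 3 ~= 36.67.
      VLOG(1) << "average = " << avg.GetAverage();

      // Hex-dump a small buffer; prints "memory = 00 01 7f".
      const char bytes[] = {0x00, 0x01, 0x7f};
      VLOG(1) << "memory =" << PrintMemory(bytes, 3);
    }

    }  // namespace tensorflow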
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
new file mode 100644
index 0000000000..d9ab0805c5
--- /dev/null
+++ b/tensorflow/core/util/work_sharder.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/util/work_sharder.h"
+
+#include <vector>
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void Shard(int num_workers, thread::ThreadPool* workers, int64 total,
+ int64 cost_per_unit, std::function<void(int64, int64)> work) {
+ CHECK_GE(total, 0);
+ if (total == 0) {
+ return;
+ }
+ if (num_workers <= 1) {
+ // Just inline the whole work since we only have 1 thread (core).
+ work(0, total);
+ return;
+ }
+ cost_per_unit = std::max(1LL, cost_per_unit);
+ // We shard [0, total) into "num_shards" shards.
+ // 1 <= num_shards <= num worker threads
+ //
+ // If total * cost_per_unit is small, it is not worth sharding too
+ // finely. Assuming each cost unit is 1ns, kMinCostPerShard=10000
+ // corresponds to 10us of work per shard.
+ static const int64 kMinCostPerShard = 10000;
+ const int num_shards = std::max(
+ 1, std::min<int>(num_workers, total * cost_per_unit / kMinCostPerShard));
+ // Each shard contains up to "block_size" units. [0, total) is sharded
+ // into:
+ // [0, block_size), [block_size, 2*block_size), ...
+ // The 1st shard is done by the caller thread and the other shards
+ // are dispatched to the worker threads. The last shard may be smaller than
+ // block_size.
+ const int64 block_size = (total + num_shards - 1) / num_shards;
+ CHECK_GT(block_size, 0); // total > 0 guarantees this.
+ if (block_size >= total) {
+ work(0, total);
+ return;
+ }
+ const int num_shards_used = (total + block_size - 1) / block_size;
+ BlockingCounter counter(num_shards_used - 1);
+ for (int64 start = block_size; start < total; start += block_size) {
+ auto limit = std::min(start + block_size, total);
+ workers->Schedule([&work, &counter, start, limit]() {
+ work(start, limit); // Compute the shard.
+ counter.DecrementCount(); // The shard is done.
+ });
+ }
+
+ // Inline execute the 1st shard.
+ work(0, std::min(block_size, total));
+ counter.Wait();
+}
+
+} // end namespace tensorflow
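To make the shard-size arithmetic above concrete with illustrative numbers: for total = 1000, cost_per_unit = 50, and num_workers = 16, total * cost_per_unit / kMinCostPerShard = 50000 / 10000 = 5, so num_shards = min(16, 5) = 5 and block_size = (1000 + 5 - 1) / 5 = 200. Since block_size < total, num_shards_used = 5: the caller computes [0, 200) inline while shards starting at 200, 400, 600, and 800 are scheduled on the thread pool, and the BlockingCounter waits for those four to finish.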
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
new file mode 100644
index 0000000000..1ea2cf4397
--- /dev/null
+++ b/tensorflow/core/util/work_sharder.h
@@ -0,0 +1,33 @@
+#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_
+#define TENSORFLOW_UTIL_WORK_SHARDER_H_
+
+#include <functional>
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+// Shards the "total" units of work, assuming each unit of work costs
+// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
+// total - 1. Each shard contains one or more units of work, and the
+// total cost of each shard is roughly the same. The total number of
+// shards is no more than num_workers. The calling thread and the
+// "workers" are used to compute each shard (by calling work(start,
+// limit)). A common configuration is that "workers" is a thread pool
+// with "num_workers" threads.
+//
+// "work" should be a callable taking (int64, int64) arguments.
+// work(start, limit) computes the work units from [start,
+// limit), i.e., [start, limit) is a shard.
+//
+// REQUIRES: num_workers >= 0
+// REQUIRES: workers != nullptr
+// REQUIRES: total >= 0
+// REQUIRES: cost_per_unit >= 0
+void Shard(int num_workers, thread::ThreadPool* workers, int64 total,
+ int64 cost_per_unit, std::function<void(int64, int64)> work);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_
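A usage sketch of the interface above; the pool size, cost estimate, and work body are illustrative, and the Env include path is assumed.

    #include <vector>

    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/public/env.h"
    #include "tensorflow/core/util/work_sharder.h"

    namespace tensorflow {

    // Illustrative only: square each element of "in" into "out" in parallel.
    void ParallelSquare(const std::vector<double>& in, std::vector<double>* out) {
      out->resize(in.size());
      thread::ThreadPool pool(Env::Default(), "square", 8 /* num threads */);
      Shard(8 /* num_workers */, &pool, in.size(), 10 /* cost_per_unit */,
            [&in, out](int64 start, int64 limit) {
              // Each shard handles the contiguous index range [start, limit).
              for (int64 i = start; i < limit; ++i) {
                (*out)[i] = in[i] * in[i];
              }
            });
      // Shard() returns only after all shards have completed.
    }

    }  // namespace tensorflow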
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
new file mode 100644
index 0000000000..d9792c0e8d
--- /dev/null
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -0,0 +1,57 @@
+#include "tensorflow/core/util/work_sharder.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace {
+
+void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit) {
+ thread::ThreadPool threads(Env::Default(), "test", 16);
+ mutex mu;
+ int64 num_shards = 0;
+ int64 num_done_work = 0;
+ std::vector<bool> work(total, false);
+ Shard(num_workers, &threads, total, cost_per_unit,
+ [&mu, &num_shards, &num_done_work, &work](int start, int limit) {
+ VLOG(1) << "Shard [" << start << "," << limit << ")";
+ mutex_lock l(mu);
+ ++num_shards;
+ for (; start < limit; ++start) {
+ EXPECT_FALSE(work[start]); // No duplicate
+ ++num_done_work;
+ work[start] = true;
+ }
+ });
+ EXPECT_LE(num_shards, num_workers + 1);
+ EXPECT_EQ(num_done_work, total);
+ LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " "
+ << num_shards;
+}
+
+TEST(Shard, Basic) {
+ for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) {
+ for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) {
+ for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+ RunSharding(workers, total, cost_per_unit);
+ }
+ }
+ }
+}
+
+void BM_Sharding(int iters, int arg) {
+ thread::ThreadPool threads(Env::Default(), "test", 16);
+ const int64 total = 1LL << 30;
+ auto lambda = [](int64 start, int64 limit) {};
+ auto work = std::cref(lambda);
+ for (; iters > 0; iters -= arg) {
+ Shard(arg - 1, &threads, total, 1, work);
+ }
+}
+BENCHMARK(BM_Sharding)->Range(1, 128);
+
+} // namespace
+} // namespace tensorflow