path: root/tensorflow/core
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD73
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt34
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt22
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt16
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt15
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt15
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt23
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc23
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h14
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc49
-rw-r--r--tensorflow/core/common_runtime/direct_session.h12
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc121
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc47
-rw-r--r--tensorflow/core/common_runtime/eager/context.h23
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc2
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h11
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device_test.cc4
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h1
-rw-r--r--tensorflow/core/common_runtime/executor.cc58
-rw-r--r--tensorflow/core/common_runtime/function.cc45
-rw-r--r--tensorflow/core/common_runtime/function_test.cc22
-rw-r--r--tensorflow/core/common_runtime/gpu/cuda_host_allocator.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc14
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h42
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc90
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc10
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h10
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc20
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h18
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc35
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc64
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h9
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.cc162
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.h58
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc68
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc45
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.h7
-rw-r--r--tensorflow/core/common_runtime/graph_runner.cc4
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h184
-rw-r--r--tensorflow/core/common_runtime/placer.cc54
-rw-r--r--tensorflow/core/common_runtime/placer.h2
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc49
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.cc46
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.h27
-rw-r--r--tensorflow/core/common_runtime/process_state.cc71
-rw-r--r--tensorflow/core/common_runtime/process_state.h15
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h7
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.cc1
-rw-r--r--tensorflow/core/common_runtime/session_state.cc2
-rw-r--r--tensorflow/core/common_runtime/single_threaded_cpu_device.h1
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.cc188
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.h137
-rw-r--r--tensorflow/core/common_runtime/tracing_device.h5
-rw-r--r--tensorflow/core/common_runtime/visitable_allocator.h79
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc2
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc8
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc15
-rw-r--r--tensorflow/core/example/example.proto8
-rw-r--r--tensorflow/core/framework/allocator.cc29
-rw-r--r--tensorflow/core/framework/allocator.h39
-rw-r--r--tensorflow/core/framework/allocator_registry.h1
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc1
-rw-r--r--tensorflow/core/framework/dataset.cc10
-rw-r--r--tensorflow/core/framework/dataset.h190
-rw-r--r--tensorflow/core/framework/dataset_stateful_op_whitelist.h33
-rw-r--r--tensorflow/core/framework/device_base.h10
-rw-r--r--tensorflow/core/framework/function.cc24
-rw-r--r--tensorflow/core/framework/function.h4
-rw-r--r--tensorflow/core/framework/function_testlib.cc50
-rw-r--r--tensorflow/core/framework/function_testlib.h6
-rw-r--r--tensorflow/core/framework/model.cc365
-rw-r--r--tensorflow/core/framework/model.h379
-rw-r--r--tensorflow/core/framework/node_def_util.cc8
-rw-r--r--tensorflow/core/framework/node_def_util.h4
-rw-r--r--tensorflow/core/framework/op_kernel.cc20
-rw-r--r--tensorflow/core/framework/op_kernel.h31
-rw-r--r--tensorflow/core/framework/op_segment.cc8
-rw-r--r--tensorflow/core/framework/op_segment.h4
-rw-r--r--tensorflow/core/framework/resource_mgr.cc2
-rw-r--r--tensorflow/core/framework/resource_mgr.h6
-rw-r--r--tensorflow/core/framework/stats_aggregator.h3
-rw-r--r--tensorflow/core/framework/tensor.h3
-rw-r--r--tensorflow/core/framework/tensor_test.cc1
-rw-r--r--tensorflow/core/framework/tensor_util.h1
-rw-r--r--tensorflow/core/framework/types.h3
-rw-r--r--tensorflow/core/framework/variant.cc25
-rw-r--r--tensorflow/core/framework/variant.h60
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h32
-rw-r--r--tensorflow/core/framework/variant_op_copy_test.cc6
-rw-r--r--tensorflow/core/framework/variant_op_registry.cc85
-rw-r--r--tensorflow/core/framework/variant_op_registry.h216
-rw-r--r--tensorflow/core/framework/variant_op_registry_test.cc96
-rw-r--r--tensorflow/core/framework/variant_tensor_data.cc22
-rw-r--r--tensorflow/core/framework/variant_tensor_data.h10
-rw-r--r--tensorflow/core/framework/variant_test.cc15
-rw-r--r--tensorflow/core/graph/graph_constructor.cc8
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc9
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc1
-rw-r--r--tensorflow/core/graph/testlib.cc27
-rw-r--r--tensorflow/core/graph/testlib.h9
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc245
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc378
-rw-r--r--tensorflow/core/grappler/costs/utils.cc8
-rw-r--r--tensorflow/core/grappler/costs/utils.h2
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc8
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer.h2
-rw-r--r--tensorflow/core/grappler/inputs/utils.cc7
-rw-r--r--tensorflow/core/grappler/inputs/utils.h4
-rw-r--r--tensorflow/core/grappler/op_types.cc39
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD71
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc59
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc88
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc173
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h7
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc230
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD44
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc106
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.h47
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc94
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc111
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.h115
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc138
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.cc167
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.h80
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info_test.cc160
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc51
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h3
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc46
-rw-r--r--tensorflow/core/grappler/utils/functions.cc32
-rw-r--r--tensorflow/core/grappler/utils/functions.h13
-rw-r--r--tensorflow/core/grappler/utils/scc.h7
-rw-r--r--tensorflow/core/kernels/BUILD116
-rw-r--r--tensorflow/core/kernels/bias_op.cc13
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD16
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantile_ops.cc453
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD4
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h96
-rw-r--r--tensorflow/core/kernels/conditional_accumulator.h6
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.cc13
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.h3
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base_op.h3
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_op.cc3
-rw-r--r--tensorflow/core/kernels/conv_2d.h45
-rw-r--r--tensorflow/core/kernels/conv_3d.h43
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc3
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc6
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc11
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.h10
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc1330
-rw-r--r--tensorflow/core/kernels/conv_ops.cc19
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc20
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h6
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc81
-rw-r--r--tensorflow/core/kernels/data/BUILD57
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc5
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc8
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc173
-rw-r--r--tensorflow/core/kernels/data/captured_function.h30
-rw-r--r--tensorflow/core/kernels/data/concatenate_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/dataset_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc6
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h6
-rw-r--r--tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/filter_by_component_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc17
-rw-r--r--tensorflow/core/kernels/data/flat_map_dataset_op.cc19
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc48
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.h2
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc53
-rw-r--r--tensorflow/core/kernels/data/interleave_dataset_op.cc20
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc23
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.h2
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc66
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc25
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc242
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc146
-rw-r--r--tensorflow/core/kernels/data/optimize_dataset_op.cc6
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc16
-rw-r--r--tensorflow/core/kernels/data/optional_ops.h2
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc5
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc664
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc56
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc51
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.h2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc11
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc15
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.h2
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner_test.cc2
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc53
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.h2
-rw-r--r--tensorflow/core/kernels/data/random_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/range_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/reader_dataset_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/repeat_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc23
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor.cc380
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor.h62
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor_test.cc332
-rw-r--r--tensorflow/core/kernels/data/skip_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/slide_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.h4
-rw-r--r--tensorflow/core/kernels/data/sql/query_connection.h3
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.h4
-rw-r--r--tensorflow/core/kernels/data/sql_dataset_ops.cc5
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/stats_dataset_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/take_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc21
-rw-r--r--tensorflow/core/kernels/data/tensor_queue_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/tensor_slice_dataset_op.cc15
-rw-r--r--tensorflow/core/kernels/data/unbatch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/window_dataset.cc2
-rw-r--r--tensorflow/core/kernels/data/window_dataset.h2
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc219
-rw-r--r--tensorflow/core/kernels/data/writer_ops.cc3
-rw-r--r--tensorflow/core/kernels/data/zip_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/debug_ops.h4
-rw-r--r--tensorflow/core/kernels/decode_bmp_op.cc7
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc3
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op.cc4
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h602
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h48
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc31
-rw-r--r--tensorflow/core/kernels/eigen_benchmark.h304
-rw-r--r--tensorflow/core/kernels/eigen_benchmark_cpu_test.cc422
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h1356
-rw-r--r--tensorflow/core/kernels/eigen_volume_patch.h1
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD2
-rw-r--r--tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc45
-rw-r--r--tensorflow/core/kernels/gather_functor.h1
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h15
-rw-r--r--tensorflow/core/kernels/gpu_utils.h3
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc3
-rw-r--r--tensorflow/core/kernels/list_kernels.h21
-rw-r--r--tensorflow/core/kernels/logistic-loss.h2
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc26
-rw-r--r--tensorflow/core/kernels/loss_test.cc174
-rw-r--r--tensorflow/core/kernels/map_stage_op.cc10
-rw-r--r--tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc4
-rw-r--r--tensorflow/core/kernels/mirror_pad_op.h1
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc31
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc42
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc49
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops_test.cc407
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.cc28
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc1
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc38
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc130
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc51
-rw-r--r--tensorflow/core/kernels/poisson-loss.h109
-rw-r--r--tensorflow/core/kernels/qr_op_complex128.cc8
-rw-r--r--tensorflow/core/kernels/qr_op_double.cc8
-rw-r--r--tensorflow/core/kernels/qr_op_float.cc8
-rw-r--r--tensorflow/core/kernels/queue_ops.cc2
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc2
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc10
-rw-r--r--tensorflow/core/kernels/regex_full_match_op.cc33
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.cc26
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc2
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.cc5
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.cc9
-rw-r--r--tensorflow/core/kernels/save_restore_v2_ops.cc4
-rw-r--r--tensorflow/core/kernels/sdca_internal.cc2
-rw-r--r--tensorflow/core/kernels/sdca_ops.cc3
-rw-r--r--tensorflow/core/kernels/shape_op_test.cc10
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator.h4
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator_op.cc4
-rw-r--r--tensorflow/core/kernels/split_op.cc7
-rw-r--r--tensorflow/core/kernels/stack_ops.cc26
-rw-r--r--tensorflow/core/kernels/string_strip_op.cc2
-rw-r--r--tensorflow/core/kernels/substr_op.cc50
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc105
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc4
-rw-r--r--tensorflow/core/kernels/typed_conditional_accumulator_base.h5
-rw-r--r--tensorflow/core/kernels/unravel_index_op.cc10
-rw-r--r--tensorflow/core/kernels/whole_file_read_ops.cc2
-rw-r--r--tensorflow/core/lib/core/errors.h18
-rw-r--r--tensorflow/core/lib/core/status.h1
-rw-r--r--tensorflow/core/lib/core/stringpiece.h6
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector.h665
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector_test.cc898
-rw-r--r--tensorflow/core/lib/gtl/optional.h853
-rw-r--r--tensorflow/core/lib/gtl/optional_test.cc1098
-rw-r--r--tensorflow/core/lib/io/block_builder.h1
-rw-r--r--tensorflow/core/lib/io/path.h1
-rw-r--r--tensorflow/core/lib/io/record_reader.cc3
-rw-r--r--tensorflow/core/lib/io/record_reader.h8
-rw-r--r--tensorflow/core/lib/io/record_writer.cc15
-rw-r--r--tensorflow/core/lib/io/record_writer.h34
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc2
-rw-r--r--tensorflow/core/lib/io/table_test.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.h2
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.h1
-rw-r--r--tensorflow/core/lib/monitoring/metric_def.h1
-rw-r--r--tensorflow/core/lib/png/png_io.h1
-rw-r--r--tensorflow/core/lib/wav/wav_io.cc5
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc125
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt727
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc2
-rw-r--r--tensorflow/core/ops/dataset_ops.cc38
-rw-r--r--tensorflow/core/ops/image_ops.cc15
-rw-r--r--tensorflow/core/ops/ops.pbtxt353
-rw-r--r--tensorflow/core/ops/parsing_ops.cc7
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc7
-rw-r--r--tensorflow/core/ops/sdca_ops.cc2
-rw-r--r--tensorflow/core/ops/string_ops.cc6
-rw-r--r--tensorflow/core/platform/abi.cc4
-rw-r--r--tensorflow/core/platform/abi.h3
-rw-r--r--tensorflow/core/platform/cloud/curl_http_request.cc4
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc16
-rw-r--r--tensorflow/core/platform/cloud/oauth_client.cc4
-rw-r--r--tensorflow/core/platform/cloud/oauth_client_test.cc6
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system.h2
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system_test.cc2
-rw-r--r--tensorflow/core/platform/cord.h26
-rw-r--r--tensorflow/core/platform/default/build_config.bzl2
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl86
-rw-r--r--tensorflow/core/platform/default/cord.h (renamed from tensorflow/core/lib/gtl/optional.cc)17
-rw-r--r--tensorflow/core/platform/default/device_tracer.cc5
-rw-r--r--tensorflow/core/platform/env_test.cc7
-rw-r--r--tensorflow/core/platform/file_system.h8
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc2
-rw-r--r--tensorflow/core/platform/posix/posix_file_system.cc2
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc2
-rw-r--r--tensorflow/core/platform/tracing.h4
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.cc2
-rw-r--r--tensorflow/core/protobuf/config.proto9
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_entry.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_scorer.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_search.h1
-rw-r--r--tensorflow/core/util/ctc/ctc_decoder.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_loss_util.h2
-rw-r--r--tensorflow/core/util/mkl_util.h39
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc10
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h4
-rw-r--r--tensorflow/core/util/status_util.h36
-rw-r--r--tensorflow/core/util/status_util_test.cc36
-rw-r--r--tensorflow/core/util/tensor_bundle/naming.h1
373 files changed, 14768 insertions, 6982 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 84b11024fd..9bcf5b0865 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -85,11 +85,12 @@ load(
"tf_cc_tests",
"tf_copts",
"tf_cuda_library",
+ "tf_features_nomodules_if_android",
"tf_gen_op_libs",
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
- "tf_features_nomodules_if_android",
+ "transitive_hdrs",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -120,16 +121,16 @@ load(
"tf_additional_libdevice_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
- "tf_additional_proto_hdrs",
"tf_additional_proto_compiler_hdrs",
+ "tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
"tf_additional_verbs_lib_defines",
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
- "tf_lib_proto_parsing_deps",
"tf_lib_proto_compiler_deps",
+ "tf_lib_proto_parsing_deps",
"tf_nano_proto_library",
"tf_platform_hdrs",
"tf_platform_srcs",
@@ -168,6 +169,7 @@ COMMON_PROTO_SRCS = [
"example/example.proto",
"example/feature.proto",
"framework/allocation_description.proto",
+ "framework/api_def.proto",
"framework/attr_value.proto",
"framework/cost_graph.proto",
"framework/device_attributes.proto",
@@ -179,7 +181,6 @@ COMMON_PROTO_SRCS = [
"framework/log_memory.proto",
"framework/node_def.proto",
"framework/op_def.proto",
- "framework/api_def.proto",
"framework/reader_base.proto",
"framework/remote_fused_graph_execute_info.proto",
"framework/resource_handle.proto",
@@ -299,6 +300,7 @@ filegroup(
name = "platform_base_hdrs",
srcs = [
"platform/byte_order.h",
+ "platform/cord.h",
"platform/env_time.h",
"platform/logging.h",
"platform/macros.h",
@@ -695,7 +697,24 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":lib_internal",
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "feature_util",
+ srcs = ["example/feature_util.cc"],
+ hdrs = [
+ "example/feature_util.h",
+ "platform/types.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core_stringpiece",
+ ":platform_protobuf",
+ ":protos_all_cc",
],
)
@@ -703,6 +722,7 @@ cc_library(
name = "abi",
srcs = ["platform/abi.cc"],
hdrs = ["platform/abi.h"],
+ deps = [":platform_base"],
)
cc_library(
@@ -857,7 +877,6 @@ tf_cuda_library(
"util/bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
- "util/env_var.h",
"util/events_writer.h",
"util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
@@ -872,7 +891,6 @@ tf_cuda_library(
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
"util/stat_summarizer_options.h",
- "util/status_util.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
@@ -939,15 +957,6 @@ cc_library(
)
cc_library(
- name = "status_util",
- hdrs = ["util/status_util.h"],
- deps = [
- ":graph",
- ":lib",
- ],
-)
-
-cc_library(
name = "reader_base",
srcs = ["framework/reader_base.cc"],
hdrs = ["framework/reader_base.h"],
@@ -1347,6 +1356,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_softmax_op",
+ "//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:mkl_aggregate_ops",
]) + if_cuda([
@@ -1418,9 +1428,11 @@ cc_library(
":test",
":testlib_ops",
"//tensorflow/cc:scope",
+ "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
+ "//tensorflow/core/kernels:random_ops",
],
)
@@ -1910,6 +1922,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "protobuf/config_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "protobuf/config.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "protobuf/device_properties_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "protobuf/device_properties.proto",
@@ -2048,6 +2067,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
"platform/snappy.h",
"platform/tensor_coding.h",
"platform/tracing.h",
+ "util/env_var.h",
]
# Replicated for lib_internal and lib_internal_impl.
@@ -2087,6 +2107,7 @@ cc_library(
"platform/*.cc",
"platform/profile_utils/**/*.cc",
"framework/resource_handle.cc",
+ "util/env_var.cc",
],
exclude = [
"**/*test*",
@@ -2442,7 +2463,6 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/unique_tensor_references.h",
"framework/variant.h",
"util/command_line_flags.h",
- "util/env_var.h",
"util/equal_graph_def.h",
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
@@ -2518,6 +2538,7 @@ tf_cuda_library(
"util/memmapped_file_system_writer.*",
"util/stats_calculator.*",
"util/version_info.cc",
+ "util/env_var.cc",
],
) + select({
"//tensorflow:windows": [],
@@ -2762,7 +2783,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
"common_runtime/tracing_device.h",
- "common_runtime/visitable_allocator.h",
"common_runtime/process_state.h",
"common_runtime/pool_allocator.h",
"graph/gradients.h",
@@ -3229,12 +3249,10 @@ tf_cc_tests(
"lib/gtl/edit_distance_test.cc",
"lib/gtl/flatmap_test.cc",
"lib/gtl/flatset_test.cc",
- "lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",
"lib/gtl/manual_constructor_test.cc",
"lib/gtl/map_util_test.cc",
- "lib/gtl/optional_test.cc",
"lib/gtl/top_n_test.cc",
"lib/hash/crc32c_test.cc",
"lib/hash/hash_test.cc",
@@ -3560,7 +3578,6 @@ tf_cc_tests(
"util/semver_test.cc",
"util/sparse/sparse_tensor_test.cc",
"util/stat_summarizer_test.cc",
- "util/status_util_test.cc",
"util/tensor_format_test.cc",
"util/tensor_slice_reader_test.cc",
"util/tensor_slice_set_test.cc",
@@ -3585,7 +3602,6 @@ tf_cc_tests(
":ops",
":protos_all_cc",
":protos_test_cc",
- ":status_util",
":test",
":test_main",
":testlib",
@@ -3724,6 +3740,7 @@ tf_cc_test_mkl(
":core_cpu_internal",
":framework",
":framework_internal",
+ ":lib",
":test",
":test_main",
":testlib",
@@ -4078,6 +4095,7 @@ tf_cuda_cc_test(
":testlib",
"//third_party/eigen3",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4119,6 +4137,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
# Link with support for TensorFlow Debugger (tfdbg).
"//tensorflow/core/debug",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4701,6 +4720,18 @@ cc_library(
] + tf_additional_libdevice_deps(),
)
+transitive_hdrs(
+ name = "headers",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:stream_executor",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets go here (must be at the end).
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000000..cdaeb5091c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
@@ -0,0 +1,34 @@
+op {
+ graph_op_name: "BoostedTreesBucketize"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensor each containing float values for a single feature.
+END
+ }
+ in_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+feature.
+END
+ }
+ out_arg {
+ name: "buckets"
+ description: <<END
+int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features.
+END
+ }
+ summary: "Bucketize each feature based on bucket boundaries."
+ description: <<END
+An op that returns a list of float tensors, where each tensor represents the
+bucketized values for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000000..20da1295f6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "BoostedTreesCreateQuantileStreamResource"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; Handle to quantile stream resource.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required approximation error of the stream resource.
+END
+ }
+ in_arg {
+ name: "num_streams"
+ description: <<END
+int; The number of streams managed by the resource that shares the same epsilon.
+END
+ }
+ attr {
+ name: "max_elements"
+ description : <<END
+int; The maximum number of data points that can be fed to the stream.
+END
+ }
+ summary: "Create the Resource for Quantile Streams."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000000..ca111af312
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "BoostedTreesMakeQuantileSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensors each containing values for a single feature.
+END
+ }
+ in_arg {
+ name: "example_weights"
+ description: <<END
+float; Rank 1 Tensor with weights per instance.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required maximum approximation error.
+END
+ }
+ out_arg {
+ name: "summaries"
+ description: <<END
+float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+min_rank, max_rank) of a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+int; Inferred from the size of float_values.
+The number of float features.
+END
+ }
+ summary: "Makes the summary of quantiles for the batch."
+ description: <<END
+An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000000..bbeecbf32b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,22 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "summaries"
+ description: <<END
+string; List of Rank 2 Tensor each containing the summaries for a single feature.
+END
+ }
+ summary: "Add the quantile summaries to each quantile stream resource."
+ description: <<END
+An op that adds a list of quantile summaries to a quantile stream resource. Each
+summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000000..2fd94efa10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,31 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceFlush"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "num_buckets",
+ description: <<END
+int; approximate number of buckets unless using generate_quantiles.
+END
+ }
+ attr {
+ name: "generate_quantiles"
+ description: <<END
+bool; If True, the output will be the num_quantiles for each stream where the ith
+entry is the ith quantile of the input with an approximation error of epsilon.
+Duplicate values may be present.
+If False, the output will be the points in the histogram that we got which roughly
+translates to 1/epsilon boundaries and without any duplicates.
+Default to False.
+END
+ }
+ summary: "Flush the summaries for a quantile stream resource."
+ description: <<END
+An op that flushes the summaries for a quantile stream resource.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000000..206672802f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ out_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features to get bucket boundaries for.
+END
+ }
+ summary: "Generate the bucket boundaries for each feature based on accumulated summaries."
+ description: <<END
+An op that returns a list of float tensors for a quantile stream resource. Each
+tensor is Rank 1 containing bucket boundaries for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000000..cb7786c051
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp"
+ visibility: HIDDEN
+ summary: "Creates a handle to a BoostedTreesQuantileStreamResource."
+}
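
Taken together, the quantile-stream ops defined above form one workflow: create the resource, add per-feature summaries, flush, read back bucket boundaries, then bucketize. Below is a minimal eager-mode sketch via tf.raw_ops, with op and argument names taken from the definitions above; the container/shared_name attributes, tensor shapes, and numeric values are illustrative assumptions rather than part of this change.

    import tensorflow as tf

    # Two hypothetical features, rank-2 as described above, plus per-example weights.
    f0 = tf.constant([[1.0], [3.0], [2.0], [4.0]])
    f1 = tf.constant([[10.0], [30.0], [20.0], [40.0]])
    weights = tf.constant([1.0, 1.0, 1.0, 1.0])

    handle = tf.raw_ops.BoostedTreesQuantileStreamResourceHandleOp(
        container="", shared_name="demo_quantile_stream")  # assumed attrs
    tf.raw_ops.BoostedTreesCreateQuantileStreamResource(
        quantile_stream_resource_handle=handle, epsilon=0.01, num_streams=2)

    summaries = tf.raw_ops.BoostedTreesMakeQuantileSummaries(
        float_values=[f0, f1], example_weights=weights, epsilon=0.01)
    tf.raw_ops.BoostedTreesQuantileStreamResourceAddSummaries(
        quantile_stream_resource_handle=handle, summaries=summaries)
    tf.raw_ops.BoostedTreesQuantileStreamResourceFlush(
        quantile_stream_resource_handle=handle, num_buckets=2)

    # Bucket boundaries per feature, then bucketized values per feature.
    boundaries = tf.raw_ops.BoostedTreesQuantileStreamResourceGetBucketBoundaries(
        quantile_stream_resource_handle=handle, num_features=2)
    buckets = tf.raw_ops.BoostedTreesBucketize(
        float_values=[f0, f1], bucket_boundaries=boundaries)
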
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
index e39213cbc7..440800704e 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
@@ -11,7 +11,8 @@ END
name: "record_defaults"
description: <<END
One tensor per column of the input record, with either a
-scalar default value for that column or empty if the column is required.
+scalar default value for that column or an empty vector if the column is
+required.
END
}
out_arg {
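
The clarified wording is the practical contract of the op: a required column is marked by an empty default vector, not by a scalar. A short sketch with the public wrapper (tf.io.decode_csv in current releases; the column layout and default values are made up for illustration):

    import tensorflow as tf

    records = tf.constant(["1,alpha", "2,"])
    # Column 0 is required (empty default vector); column 1 falls back to "missing".
    col0, col1 = tf.io.decode_csv(
        records,
        record_defaults=[tf.constant([], dtype=tf.int32), ["missing"]])
    # col0 -> [1, 2], col1 -> [b"alpha", b"missing"]
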
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000000..758eeb96f0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; The reference to quantile stream resource handle.
+END
+ }
+ out_arg {
+ name: "is_initialized"
+ description: <<END
+bool; True if the resource is initialized, False otherwise.
+END
+ }
+ summary: "Checks whether a quantile stream has been initialized."
+ description: <<END
+An Op that checks if quantile stream resource is initialized.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
new file mode 100644
index 0000000000..171add16d4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "ModelDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ summary: "Identity transformation that models performance."
+ description: <<END
+Identity transformation that models performance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000000..27bc4013c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ParallelInterleaveDatasetV2"
+ visibility: HIDDEN
+ attr {
+ name: "f"
+ description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+}
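
This op is the kernel behind the num_parallel_calls argument of tf.data.Dataset.interleave in the Python API. A minimal sketch, with the map function and parallelism values chosen only for illustration:

    import tensorflow as tf

    ds = tf.data.Dataset.range(4).interleave(
        lambda x: tf.data.Dataset.from_tensors(x).repeat(3),  # f: element -> Dataset
        cycle_length=2,
        num_parallel_calls=2)
    print(list(ds.as_numpy_iterator()))
    # [0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3]
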
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
index 8cef243aee..30fd97a0d7 100644
--- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
@@ -9,7 +9,7 @@ END
in_arg {
name: "pattern"
description: <<END
-A 1-D string tensor of the regular expression to match the input.
+A scalar string tensor containing the regular expression to match the input.
END
}
out_arg {
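
The corrected description makes explicit that pattern is a scalar matched against every element of input. For reference, a sketch with the public wrapper (the input strings are illustrative):

    import tensorflow as tf

    inputs = tf.constant(["TensorFlow", "tensorflow", "TF"])
    print(tf.strings.regex_full_match(inputs, pattern=r"Tensor.*").numpy())
    # [ True False False]
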
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index 35f55fe106..d33a36ce06 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index 70a07d9b4c..afdc39da96 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index b2e3eece38..026b5b3991 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index 7bac02e23d..a168eed87f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index a73306a892..876b860824 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
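
The same one-word correction (the size, not the rank, of segment_ids must equal the size of data's first dimension) applies to all five sorted segment ops touched above. A quick sketch with tf.math.segment_sum, using illustrative values:

    import tensorflow as tf

    data = tf.constant([1, 2, 3, 4])
    segment_ids = tf.constant([0, 0, 1, 1])  # size 4 == data.shape[0], values sorted
    print(tf.math.segment_sum(data, segment_ids).numpy())  # [3 7]
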
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000000..6d9d9908ca
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "StaticRegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is the
+regular expression to be matched with every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
index 8d6fc04847..9a89a4e8e7 100644
--- a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
@@ -32,7 +32,7 @@ END
description: <<END
a bitmask where a bit i being 1 means to ignore the begin
value and instead use the largest interval possible. At runtime
-begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
`[-1, n-1]` if `stride[i] < 0`
END
}
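
A short illustration of the begin_mask semantics described above, using the public wrapper with made-up values:

    import tensorflow as tf

    t = tf.constant([10, 20, 30, 40, 50])
    # Bit 0 of begin_mask is set, so begin[0] is ignored and the positive-stride
    # slice starts at the beginning of the dimension: equivalent to t[:3:1].
    print(tf.strided_slice(t, begin=[2], end=[3], strides=[1], begin_mask=1).numpy())
    # [10 20 30]
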
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 8fc1e5cba3..5246090ab3 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -32,8 +32,10 @@ For each string in the input `Tensor`, creates a substring starting at index
If `len` defines a substring that would extend beyond the length of the input
string, then as many characters as possible are used.
-If `pos` is negative or specifies a character index larger than any of the input
-strings, then an `InvalidArgumentError` is thrown.
+A negative `pos` indicates distance within the string backwards from the end.
+
+If `pos` specifies an index which is out of range for any of the input strings,
+then an `InvalidArgumentError` is thrown.
`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
Op creation.
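
The new negative-pos behaviour, sketched with the public wrapper (the input string is illustrative):

    import tensorflow as tf

    s = tf.constant(["Hello TensorFlow"])
    # pos = -10 counts back from the end of each string.
    print(tf.strings.substr(s, pos=-10, len=10).numpy())  # [b'TensorFlow']
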
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index 907c6d2022..7a60e4387a 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -3,15 +3,14 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
-END
+A tensor whose shape is a prefix of `data.shape`.END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the maximum along segments of a tensor."
@@ -24,13 +23,16 @@ This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the maximum such that:
-\\(output_i = \max_j data_j\\) where max is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the maximum is empty for a given segment ID `i`, it outputs the smallest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::lowest()`.
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
+
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
</div>
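
A sketch of the updated semantics: segment_ids may be any prefix of data's shape, and negative ids are dropped. The same indexing rules apply to the Min, Prod and Sum variants below. Values are illustrative:

    import tensorflow as tf

    data = tf.constant([1.0, 2.0, 3.0, 4.0])
    segment_ids = tf.constant([0, 1, 0, -1])  # the entry with id -1 is dropped
    print(tf.math.unsorted_segment_max(data, segment_ids, num_segments=2).numpy())
    # [3. 2.]
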
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 37dd973b23..7e139ddf4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -3,15 +3,15 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the minimum along segments of a tensor."
@@ -24,11 +24,14 @@ This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the minimum such that:
-\\(output_i = \min_j data_j\\) where min is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the minimum is empty for a given segment ID `i`, it outputs the largest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::max()`.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index efbc023705..9c8ea3b620 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -3,15 +3,15 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the product along segments of a tensor."
@@ -25,9 +25,12 @@ This operator is similar to the unsorted segment sum operator found
Instead of computing the sum over segments, it computes the product of all
entries belonging to a segment such that:
-\\(output_i = \prod_j data_j\\) where the product is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+`j...` such that `segment_ids[j...] == i`.
If there is no entry for a given segment ID `i`, it outputs 1.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index a8874950eb..7e5d9265c2 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -21,7 +21,7 @@ Read
for an explanation of segments.
Computes a tensor such that
-\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
need not be sorted and need not cover all values in the full
range of valid values.
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
index 1bc3660479..01387b7527 100644
--- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -2,10 +2,31 @@ op {
visibility: HIDDEN
graph_op_name: "WindowDataset"
in_arg {
- name: "window_size"
+ name: "size"
description: <<END
A scalar representing the number of elements to accumulate in a window.
END
}
+ in_arg {
+ name: "shift"
+ description: <<END
+A scalar representing the steps moving the sliding window forward in one
+iteration. It must be positive.
+END
+ }
+ in_arg {
+ name: "stride"
+ description: <<END
+A scalar representing the stride of the input elements of the sliding window.
+It must be positive.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether a window should be dropped in case its size is
+smaller than desired.
+END
+ }
summary: "A dataset that creates window datasets from the input dataset."
}
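
The renamed size argument and the new shift, stride and drop_remainder inputs correspond to tf.data.Dataset.window in the Python API. A minimal sketch with illustrative values:

    import tensorflow as tf

    windows = tf.data.Dataset.range(7).window(size=3, shift=2, stride=1,
                                              drop_remainder=True)
    for w in windows:                      # each element is itself a small dataset
        print([e.numpy() for e in w])
    # [0, 1, 2]
    # [2, 3, 4]
    # [4, 5, 6]
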
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 3bf0532491..3843ea9e60 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -31,7 +31,7 @@ namespace tensorflow {
BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
bool allow_growth, const string& name)
- : suballocator_(sub_allocator),
+ : sub_allocator_(sub_allocator),
name_(name),
free_chunks_list_(kInvalidChunkHandle),
next_allocation_id_(1) {
@@ -72,7 +72,7 @@ BFCAllocator::~BFCAllocator() {
VLOG(2) << "Number of regions allocated: "
<< region_manager_.regions().size();
for (const auto& region : region_manager_.regions()) {
- suballocator_->Free(region.ptr(), region.memory_size());
+ sub_allocator_->Free(region.ptr(), region.memory_size());
}
for (BinNum b = 0; b < kNumBins; b++) {
@@ -108,7 +108,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Try allocating.
size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes);
- void* mem_addr = suballocator_->Alloc(alignment, bytes);
+ void* mem_addr = sub_allocator_->Alloc(alignment, bytes);
if (mem_addr == nullptr && !started_backpedal_) {
// Only backpedal once.
started_backpedal_ = true;
@@ -119,7 +119,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
while (mem_addr == nullptr) {
bytes = RoundedBytes(bytes * kBackpedalFactor);
if (bytes < rounded_bytes) break;
- mem_addr = suballocator_->Alloc(alignment, bytes);
+ mem_addr = sub_allocator_->Alloc(alignment, bytes);
}
}
@@ -158,10 +158,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Insert the chunk into the right bin.
InsertFreeChunkIntoBin(h);
- // Invoke visitors on newly allocated region.
- for (const auto& visitor : region_visitors_) {
- visitor(mem_addr, bytes);
- }
return true;
}
@@ -490,15 +486,6 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
InsertFreeChunkIntoBin(coalesced_chunk);
}
-void BFCAllocator::AddAllocVisitor(Visitor visitor) {
- VLOG(1) << "AddVisitor";
- mutex_lock l(lock_);
- region_visitors_.push_back(visitor);
- for (const auto& region : region_manager_.regions()) {
- visitor(region.ptr(), region.memory_size());
- }
-}
-
bool BFCAllocator::TracksAllocationSizes() { return true; }
size_t BFCAllocator::RequestedSize(const void* ptr) {
@@ -596,7 +583,7 @@ string BFCAllocator::RenderOccupancy() {
region_offset += region.memory_size();
}
- return std::string(rendered, resolution);
+ return string(rendered, resolution);
}
void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 20e1dab1d5..364071e066 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
@@ -42,7 +42,7 @@ namespace tensorflow {
// coalescing. One assumption we make is that the process using this
// allocator owns pretty much all of the memory, and that nearly
// all requests to allocate memory go through this interface.
-class BFCAllocator : public VisitableAllocator {
+class BFCAllocator : public Allocator {
public:
// Takes ownership of sub_allocator.
BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
@@ -55,11 +55,6 @@ class BFCAllocator : public VisitableAllocator {
const AllocationAttributes& allocation_attr) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
-
- // Does nothing, because memory is never freed.
- void AddFreeVisitor(Visitor visitor) override {}
-
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
@@ -423,7 +418,7 @@ class BFCAllocator : public VisitableAllocator {
// of the available memory.
bool started_backpedal_ = false;
- std::unique_ptr<SubAllocator> suballocator_;
+ std::unique_ptr<SubAllocator> sub_allocator_;
string name_;
// Structures mutable after construction
@@ -435,9 +430,6 @@ class BFCAllocator : public VisitableAllocator {
// Pointer to head of linked list of free Chunks
ChunkHandle free_chunks_list_ GUARDED_BY(lock_);
- // Called once on each region, ASAP.
- std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
-
// Counter containing the next unique identifier to assign to a
// newly-created chunk.
int64 next_allocation_id_ GUARDED_BY(lock_);
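
A minimal sketch (not part of the patch) of what the new contract looks like from a caller's side: with the visitor API removed from BFCAllocator, visitors ride on the SubAllocator, and BFCAllocator simply owns whatever SubAllocator it is handed. The Visitor signature (void* ptr, int index, size_t num_bytes) and the use of port::AlignedMalloc are assumptions for illustration.

// Hypothetical host-memory SubAllocator that runs its own alloc/free visitors.
class SimpleHostSubAllocator : public SubAllocator {
 public:
  SimpleHostSubAllocator(const std::vector<Visitor>& alloc_visitors,
                         const std::vector<Visitor>& free_visitors)
      : SubAllocator(alloc_visitors, free_visitors) {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    void* ptr = port::AlignedMalloc(num_bytes, alignment);
    if (ptr != nullptr) {
      VisitAlloc(ptr, /*index=*/0, num_bytes);  // notify allocation visitors
    }
    return ptr;
  }

  void Free(void* ptr, size_t num_bytes) override {
    if (ptr != nullptr) {
      VisitFree(ptr, /*index=*/0, num_bytes);  // notify free visitors
      port::AlignedFree(ptr);
    }
  }
};

// BFCAllocator takes ownership of the sub-allocator it is given.
BFCAllocator bfc(new SimpleHostSubAllocator(/*alloc_visitors=*/{},
                                            /*free_visitors=*/{}),
                 /*total_memory=*/1 << 30, /*allow_growth=*/true, "host_bfc");
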
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index f8cb854b52..cf3d1f0b79 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -358,7 +358,7 @@ static Status WrappedTensorDeviceCopy(
#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+ Tensor, DIRECTION, WrappedTensorDeviceCopy)
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index bf1d78ec65..af5d5b17e7 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -451,8 +451,22 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
#ifndef __ANDROID__
- // Set up for collectives if the RunOption declares a key.
- if (run_options.experimental().collective_graph_key() > 0) {
+ // Set up for collectives if ExecutorsAndKeys declares a key.
+ if (executors_and_keys->collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ if (run_options.experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ // If a collective_graph_key was specified in run_options, ensure that it
+ // matches what came out of GraphExecutionState::BuildGraph().
+ if (run_options.experimental().collective_graph_key() !=
+ executors_and_keys->collective_graph_key) {
+ return errors::Internal(
+ "collective_graph_key in RunOptions ",
+ run_options.experimental().collective_graph_key(),
+ " should match collective_graph_key from optimized graph ",
+ executors_and_keys->collective_graph_key);
+ }
+ }
if (!collective_executor_mgr_) {
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
@@ -678,10 +692,16 @@ Status DirectSession::Run(const RunOptions& run_options,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
+ run_state_args.collective_graph_key =
+ run_options.experimental().collective_graph_key();
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
+ {
+ mutex_lock l(collective_graph_key_lock_);
+ collective_graph_key_ = executors_and_keys->collective_graph_key;
+ }
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
@@ -1116,6 +1136,8 @@ Status DirectSession::CreateExecutors(
BuildGraphOptions options;
options.callable_options = callable_options;
options.use_function_convention = !run_state_args->is_partial_run;
+ options.collective_graph_key =
+ callable_options.run_options().experimental().collective_graph_key();
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1123,9 +1145,9 @@ Status DirectSession::CreateExecutors(
ek->callable_options = callable_options;
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
- TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
- run_state_args, &ek->input_types,
- &ek->output_types));
+ TF_RETURN_IF_ERROR(CreateGraphs(
+ options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
+ &ek->output_types, &ek->collective_graph_key));
if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph);
@@ -1180,14 +1202,11 @@ Status DirectSession::CreateExecutors(
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -1200,13 +1219,11 @@ Status DirectSession::CreateExecutors(
create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
delete kernel;
- }
};
- optimizer.Optimize(lib, options_.env, device, &iter->second,
+ optimizer.Optimize(lib, options_.env, device, &partition_graph,
/*shape_map=*/nullptr);
// TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
@@ -1353,6 +1370,9 @@ Status DirectSession::GetOrCreateExecutors(
}
*callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options;
+ callable_options.mutable_run_options()
+ ->mutable_experimental()
+ ->set_collective_graph_key(run_state_args->collective_graph_key);
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR(
@@ -1379,7 +1399,7 @@ Status DirectSession::CreateGraphs(
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types) {
+ DataTypeVector* output_types, int64* collective_graph_key) {
mutex_lock l(graph_def_lock_);
std::unique_ptr<ClientGraph> client_graph;
@@ -1403,6 +1423,7 @@ Status DirectSession::CreateGraphs(
TF_RETURN_IF_ERROR(
execution_state->BuildGraph(subgraph_options, &client_graph));
}
+ *collective_graph_key = client_graph->collective_graph_key;
if (subgraph_options.callable_options.feed_size() !=
client_graph->feed_types.size()) {
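
A minimal sketch (not part of the patch) of the client-side flow these checks enable: a collective graph key passed through RunOptions is now verified against the key produced by GraphExecutionState::BuildGraph(), and a mismatch yields errors::Internal. The key value, tensor names, graph_def, and input_tensor below are placeholders.

std::unique_ptr<Session> session(NewSession(SessionOptions()));
TF_CHECK_OK(session->Create(graph_def));  // graph_def built elsewhere

RunOptions run_options;
run_options.mutable_experimental()->set_collective_graph_key(42);

std::vector<Tensor> outputs;
RunMetadata run_metadata;
// Fails with errors::Internal if 42 does not match the optimized graph's key.
Status s = session->Run(run_options, {{"input:0", input_tensor}},
                        {"collective_output:0"}, /*target_node_names=*/{},
                        &outputs, &run_metadata);
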
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 55a6fbce6d..c2cf3c7fd7 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -117,6 +117,9 @@ class DirectSession : public Session {
::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
private:
+ // For access to collective_graph_key_.
+ friend class DirectSessionCollectiveTest;
+
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {
@@ -150,6 +153,8 @@ class DirectSession : public Session {
DataTypeVector output_types;
CallableOptions callable_options;
+
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// A FunctionInfo object is created for every unique set of feeds/fetches.
@@ -203,6 +208,7 @@ class DirectSession : public Session {
string handle;
std::unique_ptr<Graph> graph;
const DebugOptions& debug_options;
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// Initializes the base execution state given the 'graph',
@@ -234,7 +240,7 @@ class DirectSession : public Session {
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types);
+ DataTypeVector* output_types, int64* collective_graph_key);
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
@@ -391,6 +397,10 @@ class DirectSession : public Session {
Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
+ // For testing collective graph key generation.
+ mutex collective_graph_key_lock_;
+ int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
+
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
// EXPERIMENTAL: debugger (tfdbg) related

diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 4b51b20bb1..65e816c202 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1255,7 +1255,7 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
@@ -1308,7 +1308,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
@@ -2218,4 +2218,121 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace
+
+class DirectSessionCollectiveTest : public ::testing::Test {
+ public:
+ // Creates a graph with CollectiveOps inside functions and runs it. Returns
+ // the generated collective_graph_key.
+ Status RunGraphWithCollectiveFunctions(bool add_unused_function,
+ int64* collective_graph_key) {
+ GraphDef g = CreateGraph(add_unused_function);
+ const Tensor t1 =
+ test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
+ const Tensor t2 =
+ test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
+ auto session = CreateSession();
+ TF_RETURN_IF_ERROR(session->Create(g));
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(
+ session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
+ {"collective_call1:0", "collective_call2:0"}, &outputs));
+ DirectSession* direct_session = static_cast<DirectSession*>(session.get());
+ {
+ mutex_lock l(direct_session->collective_graph_key_lock_);
+ *collective_graph_key = direct_session->collective_graph_key_;
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Creates a function with name `function_name` and a single CollectiveReduce
+ // node with instance key set as `instance_key`.
+ FunctionDef CollectiveFunction(const string& function_name,
+ int instance_key) {
+ return FunctionDefHelper::Define(
+ // Function name
+ function_name,
+ // In def
+ {"arg:float"},
+ // Out def
+ {"reduce:float"},
+ // Attr def
+ {},
+ // Node def
+ {{
+ {"reduce"},
+ "CollectiveReduce",
+ {"arg"},
+ {{"group_size", 2},
+ {"group_key", 1},
+ {"instance_key", instance_key},
+ {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
+ {"merge_op", "Add"},
+ {"final_op", "Div"},
+ {"T", DT_FLOAT}},
+ }});
+ }
+
+ // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
+ // CPU1, with instance_key 1, and appropriate placeholder inputs. If
+ // `add_unused_function` is true, adds another CollectiveFunction with
+ // instance_key 2 that is not invoked in the graph.
+ GraphDef CreateGraph(bool add_unused_function) {
+ GraphDef g;
+ FunctionDef collective_function =
+ CollectiveFunction("CollectiveFunction1", 1);
+ FunctionDefLibrary* lib = g.mutable_library();
+ *lib->add_function() = collective_function;
+ if (add_unused_function) {
+ FunctionDef unused_function =
+ CollectiveFunction("CollectiveFunction2", 2);
+ *lib->add_function() = unused_function;
+ }
+
+ // Inputs.
+ AttrValue dtype_attr;
+ SetAttrValue(DT_FLOAT, &dtype_attr);
+ NodeDef input1;
+ input1.set_name("input1");
+ input1.set_op("Placeholder");
+ input1.mutable_attr()->insert({"dtype", dtype_attr});
+ NodeDef input2;
+ input2.set_name("input2");
+ input2.set_op("Placeholder");
+ input2.mutable_attr()->insert({"dtype", dtype_attr});
+
+ // CollectiveReduce on CPU0 with instance_key 1.
+ NodeDef collective_call1;
+ collective_call1.set_name("collective_call1");
+ collective_call1.set_op("CollectiveFunction1");
+ collective_call1.add_input("input1");
+ collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
+ // CollectiveReduce on CPU1 with instance_key 1.
+ NodeDef collective_call2;
+ collective_call2.set_name("collective_call2");
+ collective_call2.set_op("CollectiveFunction1");
+ collective_call2.add_input("input2");
+    collective_call2.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
+
+ *g.add_node() = input1;
+ *g.add_node() = input2;
+ *g.add_node() = collective_call1;
+ *g.add_node() = collective_call2;
+
+ return g;
+ }
+};
+
+#ifndef GOOGLE_CUDA
+// TODO(ayushd): enable this test for GPU builds.
+TEST_F(DirectSessionCollectiveTest,
+ TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
+ int64 key1;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
+ int64 key2;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
+ ASSERT_EQ(key1, key2);
+}
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 39a3b49cd1..18420b60fd 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -32,34 +32,55 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
return default_val;
}
+std::unique_ptr<thread::ThreadPool> EagerThreadPool(
+ const SessionOptions& opts) {
+ SessionOptions opts_copy(opts);
+ if (opts_copy.config.inter_op_parallelism_threads() == 0) {
+ // Eager defaults to a single thread when no threads are specified.
+ opts_copy.config.set_inter_op_parallelism_threads(1);
+ }
+
+ return std::unique_ptr<thread::ThreadPool>(
+ NewThreadPoolFromSessionOptions(opts_copy));
+}
+
} // namespace
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
- bool async, std::unique_ptr<DeviceMgr> device_mgr,
+ bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous)
+ : EagerContext(opts, default_policy, async, device_mgr.release(),
+ /*device_mgr_owned*/ true, rendezvous) {}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy,
+ bool async, const DeviceMgr* device_mgr,
+ bool device_mgr_owned, Rendezvous* rendezvous)
: policy_(default_policy),
- local_device_manager_(std::move(device_mgr)),
- local_unowned_device_manager_(nullptr),
- devices_(local_device_manager_->ListDevices()),
+ devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
- thread_pool_(NewThreadPoolFromSessionOptions(opts)),
+ thread_pool_(EagerThreadPool(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
- local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
+ device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
+ thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
+ log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
use_send_tensor_rpc_(false) {
- InitDeviceMapAndAsync();
- if (opts.config.inter_op_parallelism_threads() > 0) {
- runner_ = [this](std::function<void()> closure) {
- this->thread_pool_->Schedule(closure);
- };
+ if (device_mgr_owned) {
+ local_device_manager_.reset(device_mgr);
+ local_unowned_device_manager_ = nullptr;
} else {
- runner_ = [](std::function<void()> closure) { closure(); };
+ local_unowned_device_manager_ = device_mgr;
}
+ InitDeviceMapAndAsync();
+ runner_ = [this](std::function<void()> closure) {
+ this->thread_pool_->Schedule(std::move(closure));
+ };
}
void EagerContext::InitDeviceMapAndAsync() {
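
A minimal sketch (not part of the patch) of the new unowned-DeviceMgr constructor. The DeviceFactory/DeviceMgr/Rendezvous setup is an assumption for illustration, and DEVICE_PLACEMENT_SILENT is just one of the placement policies.

SessionOptions opts;
std::vector<Device*> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
    opts, "/job:localhost/replica:0/task:0", &devices));
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(devices));
Rendezvous* rendezvous = new IntraProcessRendezvous(device_mgr.get());

// The caller keeps ownership of device_mgr; EagerContext only borrows it.
EagerContext ctx(opts, DEVICE_PLACEMENT_SILENT, /*async=*/false,
                 device_mgr.get(), /*device_mgr_owned=*/false, rendezvous);
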
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 3c95ac590d..5ed6057ec6 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#endif
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -65,10 +66,17 @@ enum ContextDevicePlacementPolicy {
class EagerContext {
public:
- explicit EagerContext(const SessionOptions& opts,
- ContextDevicePlacementPolicy default_policy, bool async,
- std::unique_ptr<DeviceMgr> device_mgr,
- Rendezvous* rendezvous);
+ // TODO: remove this constructor once we migrate all callers to the next one.
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
+ Rendezvous* rendezvous);
+
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ const DeviceMgr* device_mgr, bool device_mgr_owned,
+ Rendezvous* rendezvous);
+
~EagerContext();
// Returns the function library runtime for the given device.
@@ -134,6 +142,7 @@ class EagerContext {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() { return log_device_placement_; }
+ bool LogMemory() { return log_memory_; }
Rendezvous* GetRendezvous() { return rendezvous_; }
@@ -207,8 +216,8 @@ class EagerContext {
thread_local_policies_ GUARDED_BY(policy_map_mu_);
// Only one of the below is set.
- std::unique_ptr<DeviceMgr> local_device_manager_;
- DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<const DeviceMgr> local_device_manager_;
+ const DeviceMgr* local_unowned_device_manager_;
std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager
@@ -254,6 +263,8 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
+ const bool log_memory_;
+
Env* const env_;
#ifndef __ANDROID__
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 5b3a64ba98..1da1326a9a 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -296,7 +296,7 @@ Status EagerLocalExecute(EagerOperation* op,
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new KernelAndDevice(ctx->GetRendezvous());
+ kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
auto* flr = ctx->func_lib(device);
if (flr == nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3d61ff4dc2..83d8425477 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -32,21 +32,6 @@ limitations under the License.
namespace tensorflow {
// static
-Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out) {
- OpKernel* k = nullptr;
- Status s = CreateOpKernel(device->device_type().c_str(), device,
- device->GetAllocator(AllocatorAttributes()),
- nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
- out->device_ = device;
- out->kernel_.reset(k);
- out->flib_ = nullptr;
- out->runner_ = nullptr;
- out->default_runner_ = [](std::function<void()> f) { f(); };
- return s;
-}
-
-// static
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out) {
@@ -95,6 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
params.cancellation_manager = &cm_;
+ params.log_memory = log_memory_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 0ef419cbaa..04151a1171 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -52,12 +52,12 @@ class KernelAndDevice {
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out);
- // TODO(ashankar): Remove this
- static Status InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out);
- KernelAndDevice(tensorflow::Rendezvous* rendez)
- : device_(nullptr), flib_(nullptr), rendez_(rendez) {}
+ KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+ : device_(nullptr),
+ flib_(nullptr),
+ rendez_(rendez),
+ log_memory_(log_memory) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@@ -87,6 +87,7 @@ class KernelAndDevice {
DataTypeVector output_dtypes_;
std::function<void(std::function<void()>)>* runner_;
std::function<void(std::function<void()>)> default_runner_;
+ const bool log_memory_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index 6abe98f53c..da280b2317 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) {
.NumInputs(2)
.BuildNodeDef());
TestEnv env;
- KernelAndDevice k(nullptr);
+ KernelAndDevice k(nullptr, false);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
@@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) {
.NumInputs(inputs.size())
.BuildNodeDef());
TestEnv env;
- KernelAndDevice kernel(nullptr);
+ KernelAndDevice kernel(nullptr, false);
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
nullptr, &kernel));
tensorflow::testing::StartTiming();
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index b912f7d37b..d58724cbfa 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
- CHECK(remote_shape_ != nullptr);
*num_dims = remote_shape_->dims();
} else {
TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) {
return Status::OK();
}
+Status TensorHandle::NumElements(int64* num_elements) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *num_elements = remote_shape_->num_elements();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_elements != nullptr);
+
+ *num_elements = tensor_.NumElements();
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1bc9c6531a..e55f1a0338 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted {
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
+ Status NumElements(int64* num_elements);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
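
A minimal sketch (not part of the patch) of the new accessor; `handle` is a placeholder TensorHandle* obtained elsewhere, and the snippet assumes a Status-returning caller. For remote handles the call blocks until the remote shape is known, mirroring NumDims/Dim.

int64 num_elements = 0;
TF_RETURN_IF_ERROR(handle->NumElements(&num_elements));
// num_elements now holds the flat element count without materializing a
// full TensorShape on the caller's side.
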
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 3ef6d35182..d0a0767d6b 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -76,56 +76,47 @@ bool IsInitializationOp(const Node* node) {
namespace nodestats {
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) {
+void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
if (!stats) return;
stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}
-void SetAllStart(NodeExecStatsWrapper* stats) {
+void SetAllStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorStarted();
}
-void SetOpStart(NodeExecStatsWrapper* stats) {
+void SetOpStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeStarted();
}
-void SetOpEnd(NodeExecStatsWrapper* stats) {
+void SetOpEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeEnded();
}
-void SetAllEnd(NodeExecStatsWrapper* stats) {
+void SetAllEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorEnded();
}
-void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
if (!stats) return;
stats->SetOutput(slot, v);
}
-void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
if (!stats) return;
stats->SetMemory(ctx);
}
-void SetReferencedTensors(NodeExecStatsWrapper* stats,
+void SetReferencedTensors(NodeExecStatsInterface* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
stats->SetReferencedTensors(tensors);
}
-// Sets the timeline_label field of *stats, using data from *node.
-// Returns true iff the node is a transfer node.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- if (!stats) {
- return false;
- }
- return stats->SetTimelineLabel(node);
-}
-
} // namespace nodestats
class ExecutorImpl;
@@ -1301,7 +1292,7 @@ class ExecutorState {
// After item->kernel computation is done, processes its outputs.
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStatsWrapper* stats);
+ EntryVector* outputs, NodeExecStatsInterface* stats);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
@@ -1312,7 +1303,7 @@ class ExecutorState {
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
@@ -1482,6 +1473,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
const Status fill_status =
device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
+ delete this;
done(fill_status);
return;
}
@@ -1492,6 +1484,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
ready.push_back(TaggedNode{n, root_frame_, 0, false});
}
if (ready.empty()) {
+ delete this;
done(Status::OK());
} else {
num_outstanding_ops_ = ready.size();
@@ -1511,7 +1504,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem* _item, Entry* _first_input,
- NodeExecStatsWrapper* _stats)
+ NodeExecStatsInterface* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1536,7 +1529,7 @@ struct ExecutorState::AsyncState {
const NodeItem* item;
Entry* first_input;
OpKernelContext ctx;
- NodeExecStatsWrapper* stats;
+ NodeExecStatsInterface* stats;
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1581,7 +1574,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
params.stats_collector = stats_collector_;
Status s;
- NodeExecStatsWrapper* stats = nullptr;
+ NodeExecStatsInterface* stats = nullptr;
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1611,7 +1604,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper(node->name());
+ stats = stats_collector_->CreateNodeExecStats(node);
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
@@ -1669,7 +1662,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
auto done = [this, state]() {
Device* device = impl_->params_.device;
- NodeExecStatsWrapper* stats = state->stats; // Shorthand
+ NodeExecStatsInterface* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
nodestats::SetOpEnd(stats);
@@ -1860,7 +1853,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
- NodeExecStatsWrapper* stats) {
+ NodeExecStatsInterface* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -2078,16 +2071,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr &&
- !nodestats::SetTimelineLabel(node, stats)) {
- // Only record non-transfer nodes.
- // Transfers 'stats' ownership to 'stats_collector_'.
- stats_collector_->Save(impl_->params_.device->name(), stats);
- } else if (stats) {
- delete stats;
+ if (stats) {
+ if (stats_collector_) {
+ stats->Done(impl_->params_.device->name());
+ } else {
+ delete stats;
+ }
}
bool abort_run = false;
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 46bb8d92f8..472865ca43 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -414,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(
device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
&fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, graph_def_version_, &s);
- *kernel = new CallOp(handle, &construction);
- if (!s.ok()) {
- delete *kernel;
+ if (s.ok()) {
+ *kernel = new CallOp(handle, &construction);
}
return s;
}
@@ -615,11 +614,14 @@ void PruneFunctionBody(Graph* g) {
std::unordered_set<const Node*> nodes;
for (auto n : g->nodes()) {
// NOTE(mrry): "_Retval" nodes are stateful, and so will be added
- // to the seed set of `nodes`.
+ // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
+ // specifically exclude them as seeds, to avoid unconditionally executing
+ // unused argument nodes (e.g. in a function like `lambda x, y: y`).
// TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
// still needed. It would be preferable to prune entire loops and/or
// conditionals if they are not used in the graph.
- if (n->IsControlFlow() || n->op_def().is_stateful()) {
+ if (n->IsControlFlow() ||
+ (n->op_def().is_stateful() && n->type_string() != kArgOp)) {
nodes.insert(n);
}
}
@@ -925,29 +927,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
DCHECK(run_opts.runner != nullptr);
- Executor::Args* exec_args = new Executor::Args;
+ Executor::Args exec_args;
// Inherit the step_id from the caller.
- exec_args->step_id = run_opts.step_id;
- exec_args->rendezvous = run_opts.rendezvous;
- exec_args->stats_collector = run_opts.stats_collector;
- exec_args->cancellation_manager = run_opts.cancellation_manager;
- exec_args->collective_executor = run_opts.collective_executor;
- exec_args->step_container = run_opts.step_container;
- exec_args->runner = *run_opts.runner;
- exec_args->call_frame = frame;
-
- item->exec->RunAsync(
- // Executor args
- *exec_args,
- // Done callback.
- std::bind(
- [item, frame, exec_args](DoneCallback done,
- // Start unbound arguments.
- const Status& status) {
- delete exec_args;
- done(status);
- },
- std::move(done), std::placeholders::_1));
+ exec_args.step_id = run_opts.step_id;
+ exec_args.rendezvous = run_opts.rendezvous;
+ exec_args.stats_collector = run_opts.stats_collector;
+ exec_args.cancellation_manager = run_opts.cancellation_manager;
+ exec_args.collective_executor = run_opts.collective_executor;
+ exec_args.step_container = run_opts.step_container;
+ exec_args.runner = *run_opts.runner;
+ exec_args.call_frame = frame;
+
+ item->exec->RunAsync(exec_args, std::move(done));
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
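
A minimal sketch (not part of the patch) restating the new pruning seed rule as a standalone predicate, using the same example as the comment above (`lambda x, y: y`, where `x` is unused and should be prunable):

// Hypothetical helper equivalent to the seed test inside PruneFunctionBody().
bool IsPruneSeed(const Node* n) {
  // Keep control-flow nodes and stateful ops as seeds, but exclude "_Arg"
  // nodes even though they are stateful, so unused arguments can be pruned.
  return n->IsControlFlow() ||
         (n->op_def().is_stateful() && n->type_string() != kArgOp);
}
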
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 120f480198..7bab9be9a6 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
// Name
"SquareAndAddOneWithStatefulNodes",
// Args
- {"x: int32"},
+ {"x: int32", "y: float32"},
// Return values
- {"y: int32"},
+ {"z: int32"},
// Attrs
{},
// Nodes
@@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
"RandomUniform",
{"shape"},
{{"T", T}, {"dtype", DT_FLOAT}}},
- // y = Add<T>(a, o)
- {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+ // z = Add<T>(a, o)
+ {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
Init({stateful_func});
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- Tensor y;
+ auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
+ Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(
@@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
- TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
- {x}, {&y}));
- test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+ {x, y}, {&z}));
+ test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
stats_collector.FinalizeAndSwap(&stats);
- // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+ // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
+ // execute.
std::set<string> expected_node_names(
- {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+ {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());
diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
index 636cd43575..6bd29ef775 100644
--- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
@@ -26,8 +26,12 @@ namespace tensorflow {
class CUDAHostAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
- : stream_exec_(stream_exec) {
+ explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ numa_node_(numa_node) {
CHECK(stream_exec_ != nullptr);
}
~CUDAHostAllocator() override {}
@@ -39,19 +43,23 @@ class CUDAHostAllocator : public SubAllocator {
if (ptr == nullptr) {
LOG(WARNING) << "could not allocate pinned host memory of size: "
<< num_bytes;
+ return ptr;
}
+ VisitAlloc(ptr, numa_node_, num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, numa_node_, num_bytes);
stream_exec_->HostMemoryDeallocate(ptr);
}
}
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const int numa_node_;
TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
};
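
A minimal sketch (not part of the patch) of constructing the pinned-host sub-allocator under the new signature; the Visitor signature (void* ptr, int index, size_t num_bytes) and the logging lambda are assumptions for illustration.

se::StreamExecutor* stream_exec =
    GpuIdUtil::ExecutorForPlatformGpuId(PlatformGpuId(0)).ValueOrDie();
SubAllocator::Visitor log_alloc = [](void* ptr, int numa_node,
                                     size_t num_bytes) {
  VLOG(1) << "pinned host alloc: " << num_bytes << " bytes on NUMA node "
          << numa_node << " at " << ptr;
};
// The visitor now fires inside the sub-allocator itself rather than through
// VisitableAllocator::AddAllocVisitor.
SubAllocator* host_sub_allocator = new CUDAHostAllocator(
    stream_exec, /*numa_node=*/0,
    /*alloc_visitors=*/{log_alloc}, /*free_visitors=*/{});
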
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index c8db384b64..44ffce77a1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,19 +22,15 @@ limitations under the License.
namespace tensorflow {
-GPUBFCAllocator::GPUBFCAllocator(PlatformGpuId platform_gpu_id,
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
size_t total_memory, const string& name)
- : GPUBFCAllocator(platform_gpu_id, total_memory, GPUOptions(), name) {}
+ : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
-GPUBFCAllocator::GPUBFCAllocator(PlatformGpuId platform_gpu_id,
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
size_t total_memory,
const GPUOptions& gpu_options,
const string& name)
- : BFCAllocator(
- new GPUMemAllocator(
- GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
- gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
- gpu_options.experimental().use_unified_memory()),
- total_memory, gpu_options.allow_growth(), name) {}
+ : BFCAllocator(sub_allocator, total_memory, gpu_options.allow_growth(),
+ name) {}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index 435ffb4959..3470f7a9f7 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -31,28 +31,20 @@ limitations under the License.
namespace tensorflow {
-// A GPU memory allocator that implements a 'best-fit with coalescing'
-// algorithm.
-class GPUBFCAllocator : public BFCAllocator {
- public:
- // 'platform_gpu_id' refers to the ID of the GPU device within
- // the process and must reference a valid ID in the process.
- GPUBFCAllocator(PlatformGpuId platform_gpu_id, size_t total_memory,
- const string& name);
- GPUBFCAllocator(PlatformGpuId platform_gpu_id, size_t total_memory,
- const GPUOptions& gpu_options, const string& name);
- virtual ~GPUBFCAllocator() {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
-};
-
// Suballocator for GPU memory.
class GPUMemAllocator : public SubAllocator {
public:
+ // 'platform_gpu_id' refers to the ID of the GPU device within
+ // the process and must reference a valid ID in the process.
// Note: stream_exec cannot be null.
explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
- bool use_unified_memory)
- : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) {
+ PlatformGpuId gpu_id, bool use_unified_memory,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ gpu_id_(gpu_id),
+ use_unified_memory_(use_unified_memory) {
CHECK(stream_exec_ != nullptr);
}
~GPUMemAllocator() override {}
@@ -65,12 +57,14 @@ class GPUMemAllocator : public SubAllocator {
} else {
ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
}
+ VisitAlloc(ptr, gpu_id_.value(), num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, gpu_id_.value(), num_bytes);
if (use_unified_memory_) {
stream_exec_->UnifiedMemoryDeallocate(ptr);
} else {
@@ -82,11 +76,25 @@ class GPUMemAllocator : public SubAllocator {
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const PlatformGpuId gpu_id_;
const bool use_unified_memory_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm.
+class GPUBFCAllocator : public BFCAllocator {
+ public:
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const string& name);
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const GPUOptions& gpu_options, const string& name);
+ ~GPUBFCAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index 518ccba580..e313135d8d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -46,7 +47,11 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
}
TEST(GPUBFCAllocatorTest, NoDups) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
// Allocate a lot of raw pointers
@@ -75,7 +80,11 @@ TEST(GPUBFCAllocatorTest, NoDups) {
}
TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Allocate 256 raw pointers of sizes between 100 bytes and about
// a meg
random::PhiloxRandom philox(123, 17);
@@ -133,7 +142,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
}
TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
float* first_ptr = a.Allocate<float>(1024);
@@ -168,18 +181,30 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
}
TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* ptr = a.Allocate<float>(0);
EXPECT_EQ(nullptr, ptr);
}
TEST(GPUBFCAllocatorTest, TracksSizes) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
@@ -187,8 +212,12 @@ TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
}
TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
// Configure a 1MiB byte limit
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 20, "GPU_0_bfc");
+ GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc");
float* first_ptr = a.Allocate<float>(1 << 6);
float* second_ptr = a.Allocate<float>(1 << 20);
@@ -203,7 +232,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
options.set_allow_growth(true);
// Max of 2GiB, but starts out small.
- GPUBFCAllocator a(PlatformGpuId(0), 1LL << 31, options, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+  GPUBFCAllocator a(sub_allocator, 1LL << 31, options, "GPU_0_bfc");
// Allocate 10 raw pointers of sizes between 100 bytes and about
// 64 megs.
@@ -264,8 +297,15 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
}
TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
- GPUBFCAllocator a(PlatformGpuId(0), 1UL << 60, "GPU_0_bfc");
- GPUBFCAllocator b(PlatformGpuId(0), 1UL << 60, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc");
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc");
void* amem = a.AllocateRaw(1, 1);
void* bmem = b.AllocateRaw(1, 1 << 30);
a.DeallocateRaw(amem);
@@ -273,7 +313,11 @@ TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
}
static void BM_Allocation(int iters) {
- GPUBFCAllocator a(PlatformGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<size_t> sizes = {256, 4096, 16384, 524288,
512, 1048576, 10485760, 104857600,
@@ -289,7 +333,11 @@ static void BM_Allocation(int iters) {
BENCHMARK(BM_Allocation);
static void BM_AllocationThreaded(int iters, int num_threads) {
- GPUBFCAllocator a(PlatformGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
thread::ThreadPool pool(Env::Default(), "test", num_threads);
std::atomic_int_fast32_t count(iters);
mutex done_lock;
@@ -325,7 +373,11 @@ BENCHMARK(BM_AllocationThreaded)->Arg(1)->Arg(4)->Arg(16);
// A more complex benchmark that defers deallocation of an object for
// "delay" allocations.
static void BM_AllocationDelayed(int iters, int delay) {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024};
int size_index = 0;
@@ -363,7 +415,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
// only methods inside this class can access private members of BFCAllocator.
void TestBinDebugInfo() {
- GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
std::vector<void*> initial_ptrs;
std::vector<size_t> initial_ptrs_allocated_sizes;
@@ -441,7 +497,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
}
void TestLog2FloorNonZeroSlow() {
- GPUBFCAllocator a(PlatformGpuId(0), 1 /* total_memory */, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc");
EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0));
EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1));
EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2));
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 553a5628ad..d85ca8892f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -27,7 +27,7 @@ limitations under the License.
namespace tensorflow {
-GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator,
+GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator,
PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
stream_exec_ =
@@ -61,14 +61,6 @@ void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) {
#endif // GOOGLE_CUDA
}
-void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; }
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 8f38cc5a18..8df3724bc4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -29,20 +29,18 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUcudaMallocAllocator : public VisitableAllocator {
+class GPUcudaMallocAllocator : public Allocator {
public:
- explicit GPUcudaMallocAllocator(VisitableAllocator* allocator,
+ explicit GPUcudaMallocAllocator(Allocator* allocator,
PlatformGpuId platform_gpu_id);
~GPUcudaMallocAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index badb021aa5..989ddbe4af 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -73,7 +73,7 @@ void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) {
// -----------------------------------------------------------------------------
// GPUDebugAllocator
// -----------------------------------------------------------------------------
-GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
+GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator,
PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
stream_exec_ =
@@ -112,14 +112,6 @@ void GPUDebugAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
size_t GPUDebugAllocator::RequestedSize(const void* ptr) {
@@ -159,7 +151,7 @@ bool GPUDebugAllocator::CheckFooter(void* ptr) {
// -----------------------------------------------------------------------------
// GPUNanResetAllocator
// -----------------------------------------------------------------------------
-GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
+GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator,
PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
stream_exec_ =
@@ -202,14 +194,6 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
size_t GPUNanResetAllocator::RequestedSize(const void* ptr) {
return base_allocator_->RequestedSize(ptr);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 9e007ed8c1..17757a106c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -31,16 +31,14 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUDebugAllocator : public VisitableAllocator {
+class GPUDebugAllocator : public Allocator {
public:
- explicit GPUDebugAllocator(VisitableAllocator* allocator,
+ explicit GPUDebugAllocator(Allocator* allocator,
PlatformGpuId platform_gpu_id);
~GPUDebugAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
@@ -53,7 +51,7 @@ class GPUDebugAllocator : public VisitableAllocator {
bool CheckFooter(void* ptr);
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
@@ -63,23 +61,21 @@ class GPUDebugAllocator : public VisitableAllocator {
// An allocator that wraps a GPU allocator and resets the memory on
// allocation and free to 'NaN', helping to identify cases where the
// user forgets to initialize the memory.
-class GPUNanResetAllocator : public VisitableAllocator {
+class GPUNanResetAllocator : public Allocator {
public:
- explicit GPUNanResetAllocator(VisitableAllocator* allocator,
+ explicit GPUNanResetAllocator(Allocator* allocator,
PlatformGpuId platform_gpu_id);
~GPUNanResetAllocator() override;
string Name() override { return "gpu_nan_reset"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
void GetStats(AllocatorStats* stats) override;
void ClearStats() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
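Editor's note: with AddAllocVisitor/AddFreeVisitor removed, both debug wrappers reduce to plain forwarding Allocators that own their base allocator. A minimal sketch of that shape, using illustrative names (ForwardingAllocator is not part of this change):

  class ForwardingAllocator : public Allocator {
   public:
    explicit ForwardingAllocator(Allocator* base) : base_(base) {}
    ~ForwardingAllocator() override { delete base_; }
    string Name() override { return "forwarding"; }
    void* AllocateRaw(size_t alignment, size_t num_bytes) override {
      return base_->AllocateRaw(alignment, num_bytes);
    }
    void DeallocateRaw(void* ptr) override { base_->DeallocateRaw(ptr); }
    bool TracksAllocationSizes() override {
      return base_->TracksAllocationSizes();
    }
   private:
    Allocator* base_ = nullptr;  // owned, mirroring base_allocator_ above
  };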
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index bc3e3a8c35..aca08a7e33 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -35,7 +35,10 @@ namespace {
TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
const PlatformGpuId platform_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id);
auto stream_exec =
GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
@@ -60,7 +63,10 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
EXPECT_DEATH(
{
const PlatformGpuId platform_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id);
auto stream_exec =
GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
@@ -93,7 +99,10 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
EXPECT_DEATH(
{
const PlatformGpuId platform_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id);
auto stream_exec =
GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
@@ -123,7 +132,10 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
TEST(GPUDebugAllocatorTest, ResetToNan) {
const PlatformGpuId platform_gpu_id(0);
- GPUNanResetAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id);
auto stream_exec =
GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
@@ -165,8 +177,11 @@ TEST(GPUDebugAllocatorTest, ResetToNan) {
TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
const PlatformGpuId platform_gpu_id(0);
// NaN reset must be the outer-most allocator.
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id),
platform_gpu_id);
auto stream_exec =
@@ -208,15 +223,21 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
TEST(GPUDebugAllocatorTest, TracksSizes) {
const PlatformGpuId platform_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id);
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""),
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
platform_gpu_id),
platform_gpu_id);
float* t1 = a.Allocate<float>(1);
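Taken together, the test updates above show the new construction chain: a GPUMemAllocator sub-allocator feeds a GPUBFCAllocator, which the debug wrappers then wrap. A condensed sketch for GPU 0 with a 1 GiB limit and no visitors (the allocator name is illustrative):

  const PlatformGpuId platform_gpu_id(0);
  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
      platform_gpu_id, false /*use_unified_memory*/, {}, {});
  GPUBFCAllocator* bfc =
      new GPUBFCAllocator(sub_allocator, 1 << 30, "GPU_0_bfc");
  // NaN reset must be the outer-most allocator.
  GPUNanResetAllocator a(new GPUDebugAllocator(bfc, platform_gpu_id),
                         platform_gpu_id);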
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 4bf23bc017..cf3faf68ff 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -285,6 +284,38 @@ BaseGPUDevice::~BaseGPUDevice() {
for (auto ctx : device_contexts_) ctx->Unref();
}
+// Idempotent: calling this again after initialization is a no-op.
+Status BaseGPUDevice::InitScratchBuffers() {
+ mutex_lock l(scratch_init_mutex_);
+ if (scratch_.size() < max_streams_) {
+ for (int i = 0; i < max_streams_; i++) {
+ DCHECK(streams_[i]);
+ if (scratch_.size() > i && scratch_[i]) continue;
+ size_t scratch_buffer_size =
+ Eigen::kCudaScratchSize + sizeof(unsigned int);
+ void* scratch_buffer = gpu_allocator_->AllocateRaw(
+ Allocator::kAllocatorAlignment, scratch_buffer_size);
+ if (scratch_buffer == nullptr) {
+ return errors::FailedPrecondition(
+ "Failed to allocate scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ se::DeviceMemory<char> mem(
+ se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
+
+ bool ok = executor_->SynchronousMemZero(
+ &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+ if (!ok) {
+ return errors::FailedPrecondition(
+ "Failed to memcopy into scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ scratch_.push_back(static_cast<char*>(scratch_buffer));
+ }
+ }
+ return Status::OK();
+}
+
Status BaseGPUDevice::Init(const SessionOptions& options) {
auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_);
if (!executor_status.status().ok()) {
@@ -303,27 +334,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
for (int i = 0; i < max_streams_; i++) {
streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
tf_gpu_id_, i, executor_, options.config.gpu_options()));
-
- size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
- void* scratch_buffer = gpu_allocator_->AllocateRaw(
- Allocator::kAllocatorAlignment, scratch_buffer_size);
- if (scratch_buffer == nullptr) {
- return errors::FailedPrecondition(
- "Failed to allocate scratch buffer for device ", tf_gpu_id_.value());
- }
- scratch_.push_back(static_cast<char*>(scratch_buffer));
-
- se::DeviceMemory<char> mem(
- se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
-
- bool ok = executor_->SynchronousMemZero(
- &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
- if (!ok) {
- return errors::FailedPrecondition(
- "Failed to memcopy into scratch buffer for device ",
- tf_gpu_id_.value());
- }
-
device_contexts_.push_back(new GPUDeviceContext(
i, streams_.back()->compute, streams_.back()->host_to_device,
streams_.back()->device_to_host, streams_.back()->device_to_device));
@@ -870,10 +880,11 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
return new ConcretePerOpGpuDevice();
}
-void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
- PerOpGpuDevice* device,
- DeviceContext* dc,
- Allocator* allocator) {
+Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
+ PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) {
+ TF_RETURN_IF_ERROR(InitScratchBuffers());
if (dc) {
const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
const int stream_id = gpu_dc->stream_id();
@@ -884,6 +895,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
} else {
ReinitializeDevice(context, device, 0, allocator);
}
+ return Status::OK();
}
Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 684cc0c1de..b25fe8645f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -86,8 +86,9 @@ class BaseGPUDevice : public LocalDevice {
// The caller owns the returned device.
PerOpGpuDevice* MakeGpuDevice() override;
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override;
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override;
// Returns the platform GPU id of this device within the native driver system;
// e.g., for CUDA this is the ordinal of the GPU within the system.
@@ -125,6 +126,7 @@ class BaseGPUDevice : public LocalDevice {
class StreamGroupFactory;
gtl::InlinedVector<StreamGroup*, 4> streams_;
+ mutex scratch_init_mutex_;
gtl::InlinedVector<char*, 4> scratch_;
std::vector<GPUDeviceContext*> device_contexts_;
GpuDeviceInfo* gpu_device_info_ = nullptr;
@@ -135,6 +137,9 @@ class BaseGPUDevice : public LocalDevice {
std::unique_ptr<EventMgr> em_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ // Initialize scratch buffers used by Eigen.
+ Status InitScratchBuffers();
+
void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
int stream_id, Allocator* allocator);
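Because ReinitializeGpuDevice now returns Status (so a failure in the lazy InitScratchBuffers() can propagate), call sites must check the result. A hypothetical caller, not part of this diff; gpu_device, op_kernel_context, device_context, and allocator are assumed to be in scope inside a Status-returning function:

  PerOpGpuDevice* per_op_device = gpu_device->MakeGpuDevice();
  TF_RETURN_IF_ERROR(gpu_device->ReinitializeGpuDevice(
      op_kernel_context, per_op_device, device_context, allocator));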
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index a5b46382f1..3e95374fda 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -76,12 +76,16 @@ GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
// This function is defined for debugging problems with the allocators.
GPUProcessState::~GPUProcessState() {
CHECK_EQ(this, instance_);
- for (auto p : gpu_allocators_) {
- delete p;
- }
instance_ = nullptr;
}
+int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
+ // Return the NUMA node associated with the GPU's StreamExecutor.
+ se::StreamExecutor* se =
+ GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
+ return se->GetDeviceDescription().numa_node();
+}
+
Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
TfGpuId tf_gpu_id,
size_t total_bytes) {
@@ -93,13 +97,10 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
gpu_allocators_.resize(tf_gpu_id.value() + 1);
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- gpu_al_.resize(tf_gpu_id.value() + 1);
}
- if (gpu_allocators_[tf_gpu_id.value()] == nullptr) {
- VisitableAllocator* gpu_allocator;
-
+ AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
+ if (allocator_parts.allocator.get() == nullptr) {
// Validate allocator types.
if (!allocator_type.empty() && allocator_type != "BFC") {
LOG(ERROR) << "Invalid allocator type: " << allocator_type;
@@ -108,8 +109,18 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
PlatformGpuId platform_gpu_id;
TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
- gpu_allocator =
- new GPUBFCAllocator(platform_gpu_id, total_bytes, options,
+ int bus_id = BusIdForGPU(tf_gpu_id);
+ while (bus_id >= gpu_visitors_.size()) {
+ gpu_visitors_.push_back({});
+ }
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id,
+ (options.per_process_gpu_memory_fraction() > 1.0 ||
+ options.experimental().use_unified_memory()),
+ gpu_visitors_[bus_id], {});
+ Allocator* gpu_allocator =
+ new GPUBFCAllocator(sub_allocator, total_bytes, options,
strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
// If true, checks for memory overwrites by writing
@@ -124,34 +135,25 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
gpu_allocator =
new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
}
- gpu_allocators_[tf_gpu_id.value()] = gpu_allocator;
-
- // If there are any pending AllocVisitors for this bus, add
- // them now.
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
- int bus_id = se->GetDeviceDescription().numa_node();
- if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) {
- for (const auto& v : gpu_visitors_[bus_id]) {
- gpu_allocator->AddAllocVisitor(v);
- }
- }
+
+ Allocator* recording_allocator = nullptr;
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::GPU;
md.dev_index = platform_gpu_id.value();
md.gpu_registered = false;
md.nic_registered = true;
- if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) {
- gpu_al_.resize(tf_gpu_id.value() + 1);
- }
- gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator(
+ recording_allocator = new internal::RecordingAllocator(
&process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
}
+ allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator,
+ std::unique_ptr<Allocator>(recording_allocator)};
+ }
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return allocator_parts.recording_allocator.get();
+ } else {
+ return allocator_parts.allocator.get();
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return gpu_al_[tf_gpu_id.value()];
- return gpu_allocators_[tf_gpu_id.value()];
#else
LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
return nullptr;
@@ -173,11 +175,12 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
tf_shared_lock lock(mu_);
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
- static_cast<int>(cuda_al_.size()) > 0) {
- return cuda_al_[0];
+ !cuda_host_allocators_.empty() &&
+ cuda_host_allocators_[0].recording_allocator != nullptr) {
+ return cuda_host_allocators_[0].recording_allocator.get();
}
if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) {
- return cuda_host_allocators_[0];
+ return cuda_host_allocators_[0].allocator.get();
}
}
@@ -191,7 +194,7 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
// it knows is valid.
se::StreamExecutor* se = nullptr;
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
- if (gpu_allocators_[i] != nullptr) {
+ if (gpu_allocators_[i].allocator != nullptr) {
se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
break;
}
@@ -200,6 +203,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
CHECK_NE(nullptr, se);
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+ while (cuda_host_alloc_visitors_.size() <= numa_node) {
+ cuda_host_alloc_visitors_.push_back({});
+ }
+ while (cuda_host_free_visitors_.size() <= numa_node) {
+ cuda_host_free_visitors_.push_back({});
+ }
+ SubAllocator* sub_allocator = new CUDAHostAllocator(
+ se, numa_node, cuda_host_alloc_visitors_[numa_node],
+ cuda_host_free_visitors_[numa_node]);
// TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
int64 cuda_host_mem_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB",
@@ -209,62 +221,92 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
}
int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
- VisitableAllocator* allocator =
- new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
+ Allocator* allocator =
+ new BFCAllocator(sub_allocator, cuda_host_mem_limit,
true /*allow_growth*/, "cuda_host_bfc" /*name*/);
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
- cuda_host_allocators_.push_back(allocator);
+ cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+ sub_allocator,
+ std::unique_ptr<Allocator>(nullptr)});
+ AllocatorParts& allocator_parts = cuda_host_allocators_.back();
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::CPU;
md.dev_index = 0;
md.gpu_registered = true;
md.nic_registered = false;
- cuda_al_.push_back(new internal::RecordingAllocator(
- &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md,
- &mu_));
+ allocator_parts.recording_allocator.reset(
+ new internal::RecordingAllocator(&process_state_->mem_desc_map_,
+ allocator_parts.allocator.get(), md,
+ &mu_));
}
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return cuda_al_[0];
- return cuda_host_allocators_[0];
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return cuda_host_allocators_[0].recording_allocator.get();
+ } else {
+ return cuda_host_allocators_[0].allocator.get();
+ }
}
void GPUProcessState::AddGPUAllocVisitor(int bus_id,
- const AllocVisitor& visitor) {
- CHECK(process_state_);
+ const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA
mutex_lock lock(mu_);
- for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
- if (gpu_allocators_[i] &&
- (se->GetDeviceDescription().numa_node() + 1) == bus_id) {
- gpu_allocators_[i]->AddAllocVisitor(visitor);
- }
- }
+ CHECK(gpu_allocators_.empty()) // Crash OK
+ << "AddGPUAllocVisitor must be called before "
+ "first call to GetGPUAllocator.";
while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
- gpu_visitors_.push_back(std::vector<AllocVisitor>());
+ gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
}
gpu_visitors_[bus_id].push_back(visitor);
#endif // GOOGLE_CUDA
}
+void GPUProcessState::AddCUDAHostAllocVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostAllocVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) {
+ cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_alloc_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
+void GPUProcessState::AddCUDAHostFreeVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostFreeVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) {
+ cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_free_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
void GPUProcessState::TestOnlyReset() {
- process_state_->ProcessState::TestOnlyReset();
+ if (process_state_) {
+ process_state_->ProcessState::TestOnlyReset();
+ }
{
mutex_lock lock(mu_);
gpu_device_enabled_ = false;
+ gpu_allocators_.clear();
gpu_visitors_.clear();
- gtl::STLDeleteElements(&gpu_allocators_);
- gtl::STLDeleteElements(&cuda_host_allocators_);
- gtl::STLDeleteElements(&gpu_al_);
- gtl::STLDeleteElements(&cuda_al_);
+ cuda_host_allocators_.clear();
+ cuda_host_alloc_visitors_.clear();
+ cuda_host_free_visitors_.clear();
}
}
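For reference, the unified-memory flag passed to the GPUMemAllocator above is driven entirely by GPUOptions; either of the following hypothetical settings makes GetGPUAllocator pass use_unified_memory=true:

  GPUOptions options;
  options.set_per_process_gpu_memory_fraction(2.0);  // > 1.0 oversubscribes device memory
  // or:
  options.mutable_experimental()->set_use_unified_memory(true);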
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index cb41c3c6bd..43e9a31660 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -32,7 +32,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state when GPUs are present.
@@ -72,18 +71,30 @@ class GPUProcessState {
virtual Allocator* GetCUDAHostAllocator(int numa_node);
- // Registers a function to be called once on every new Region
- // allocated by every GPURegionAllocator proximate to the specified
- // bus. The AllocVisitor is provided with a memory pointer and the
- // size of the area it identifies. The pointer is not guaranteed to
- // be valid after the call terminates. The intention is for this
- // interface to be used for network device memory registration.
- // "bus_id" is platform-specific. On many platforms it
- // should be 0. On machines with multiple PCIe buses, it should be
- // the index of one of the PCIe buses. If the bus_id is invalid,
- // results are undefined.
- typedef std::function<void(void*, size_t)> AllocVisitor;
- virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor);
+ // Registers a Visitor to be invoked on new chunks of memory allocated by the
+ // SubAllocator of every GPU proximate to the specified bus. The visitor
+ // is provided with a memory pointer, a GPU id, and the size of the area it
+ // identifies. The pointer is not guaranteed to be valid after the call
+ // terminates. The intention is for this interface to be used for network
+ // device memory registration. "bus_id" is platform-specific. On many
+ // platforms it should be 0. On machines with multiple PCIe buses, it should
+ // be the index of one of the PCIe buses (maybe the NUMA node at which the
+ // PCIe is rooted). If the bus_id is invalid, results are undefined.
+ virtual void AddGPUAllocVisitor(int bus_id,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on new chunks of memory allocated by
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostAllocVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on each chunk handed back for freeing to
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostFreeVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Returns bus_id for the given GPU id.
+ virtual int BusIdForGPU(TfGpuId tf_gpu_id);
protected:
GPUProcessState();
@@ -103,16 +114,21 @@ class GPUProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
- std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+ struct AllocatorParts {
+ std::unique_ptr<Allocator> allocator;
+ SubAllocator* sub_allocator; // owned by allocator
+ std::unique_ptr<Allocator> recording_allocator;
+ };
+ std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_);
- virtual ~GPUProcessState();
+ std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_
+ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_
+ GUARDED_BY(mu_);
- // Optional RecordingAllocators that wrap the corresponding
- // Allocators for runtime attribute use analysis.
- std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
+ virtual ~GPUProcessState();
friend class GPUDeviceTest;
};
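A usage sketch of the registration API documented above. Per the new CHECKs in gpu_process_state.cc, visitors must be registered before the first GetGPUAllocator / GetCUDAHostAllocator call; the singleton accessor and the visitor bodies here are illustrative assumptions:

  GPUProcessState* gps = GPUProcessState::singleton();
  gps->AddGPUAllocVisitor(
      /*bus_id=*/0, [](void* ptr, int gpu_id, int64 num_bytes) {
        // e.g. register the chunk with a NIC driver for RDMA.
      });
  gps->AddCUDAHostAllocVisitor(
      /*numa_node=*/0, [](void* ptr, int numa_node, int64 num_bytes) {
        // e.g. pin the host region.
      });
  gps->AddCUDAHostFreeVisitor(
      /*numa_node=*/0, [](void* ptr, int numa_node, int64 num_bytes) {
        // e.g. unpin the host region.
      });
  // Only after registration:
  Allocator* gpu_alloc = gps->GetGPUAllocator(GPUOptions(), TfGpuId(0), 1 << 30);
  Allocator* host_alloc = gps->GetCUDAHostAllocator(/*numa_node=*/0);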
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index 583bff2c07..6b2f6547b0 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -31,7 +31,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
@@ -49,7 +50,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(0, pool.get_from_pool_count());
@@ -82,7 +84,8 @@ TEST(PoolAllocatorTest, Alignment) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
for (int i = 0; i < 16; ++i) {
size_t alignment = 1 << i;
@@ -97,8 +100,8 @@ TEST(PoolAllocatorTest, Alignment) {
TEST(PoolAllocatorTest, AutoResize) {
PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder,
- "pool");
+ new BasicCPUAllocator(0 /*numa_node*/, {}, {}),
+ new NoopRounder, "pool");
// Alloc/dealloc 10 sizes just a few times, confirming pool size
// stays at 2.
@@ -123,14 +126,32 @@ TEST(PoolAllocatorTest, AutoResize) {
}
TEST(PoolAllocatorTest, CudaHostAllocator) {
+ int alloc_count = 0;
+ int64 alloc_size = 0;
+ SubAllocator::Visitor alloc_visitor =
+ [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) {
+ ++alloc_count;
+ alloc_size += size;
+ };
+ int free_count = 0;
+ int64 free_size = 0;
+ SubAllocator::Visitor free_visitor =
+ [&free_count, &free_size](void* ptr, int numa_node, int64 size) {
+ ++free_count;
+ free_size += size;
+ };
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
- PoolAllocator pool(
- 2 /*pool_size_limit*/, false /*auto_resize*/,
- new CUDAHostAllocator(
- platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
- new NoopRounder, "pool");
+ CUDAHostAllocator* sub_allocator = new CUDAHostAllocator(
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie(),
+ 0 /*numa_node*/, {alloc_visitor}, {free_visitor});
+ PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/,
+ sub_allocator, new NoopRounder, "pool");
+ EXPECT_EQ(0, alloc_count);
+ EXPECT_EQ(0, alloc_size);
+ EXPECT_EQ(0, free_count);
+ EXPECT_EQ(0, free_size);
// Repeatedly Get a 16-byte value, confirming that there's only
// one real allocation.
@@ -138,6 +159,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(1, pool.allocated_count());
EXPECT_NE(nullptr, p1_16);
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ // Each suballocation includes a 16B ChunkPrefix.
+ static const int kChunkPrefixSize = 16;
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
pool.DeallocateRaw(p1_16);
// Pool contents {16}
EXPECT_EQ(1, pool.put_count());
@@ -148,6 +173,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
pool.DeallocateRaw(p2_16); // Put it back.
// Pool contents {16}
EXPECT_EQ(2, pool.put_count());
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// Get two more values of different sizes.
void* p3_4 = pool.AllocateRaw(4, 4);
@@ -160,6 +188,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer.
EXPECT_NE(nullptr, p4_2);
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// The pool is full: when we put back p4_2, the 16-byte buffer
// should be evicted since it was least recently inserted.
@@ -167,6 +198,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
// Pool contents {2, 4}
EXPECT_EQ(4, pool.put_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
// Re-getting and putting size 2 or 4 should not alter pool size or
// num-evicted.
@@ -180,12 +215,20 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(6, pool.put_count());
EXPECT_EQ(3, pool.allocated_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
pool.Clear();
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(0, pool.put_count());
EXPECT_EQ(0, pool.allocated_count());
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(3, free_count);
+ EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size);
}
TEST(PoolAllocatorTest, Pow2Rounder) {
@@ -206,7 +249,8 @@ TEST(PoolAllocatorTest, Name) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ("pool", pool.Name());
}
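Worked numbers for the visitor expectations in the CudaHostAllocator test above, assuming the 16-byte ChunkPrefix the test notes: the first 16-byte Get triggers one suballocation, so alloc_size == 16 + 1*16 = 32; after the additional 4-byte and 2-byte allocations, alloc_count == 3 and alloc_size == (16 + 4 + 2) + 3*16 = 70; evicting the 16-byte chunk gives free_size == 16 + 1*16 = 32; and after Clear() all three chunks are freed, so free_size == (16 + 4 + 2) + 3*16 = 70.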
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 346befc255..4475fa979e 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include <memory>
+#include <set>
#include <string>
#include <unordered_set>
#include <utility>
@@ -560,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph(
grappler::GrapplerItem item;
item.id = "tf_graph";
graph_->ToGraphDef(&item.graph);
+ // TODO(b/114748242): Add a unit test to test this bug fix.
+ if (flib_def_) {
+ *item.graph.mutable_library() = flib_def_->ToProto();
+ }
item.fetch.insert(item.fetch.end(),
options.callable_options.fetch().begin(),
@@ -727,12 +732,50 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
+ int64 collective_graph_key = options.collective_graph_key;
+ if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // BuildGraphOptions does not specify a collective_graph_key. Check all
+ // nodes in the Graph and FunctionLibraryDefinition for collective ops and
+ // if found, initialize a collective_graph_key as a hash of the ordered set
+ // of instance keys.
+ std::set<int32> instance_key_set;
+ for (Node* node : optimized_graph->nodes()) {
+ if (node->IsCollective()) {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ } else {
+ const FunctionDef* fdef = optimized_flib->Find(node->def().op());
+ if (fdef != nullptr) {
+ for (const NodeDef& ndef : fdef->node_def()) {
+ if (ndef.op() == "CollectiveReduce" ||
+ ndef.op() == "CollectiveBcastSend" ||
+ ndef.op() == "CollectiveBcastRecv") {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ndef, "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ }
+ }
+ }
+ }
+ }
+ if (!instance_key_set.empty()) {
+ uint64 hash = 0x8774aa605c729c72ULL;
+ for (int32 instance_key : instance_key_set) {
+ hash = Hash64Combine(instance_key, hash);
+ }
+ collective_graph_key = hash;
+ }
+ }
+
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
std::unique_ptr<ClientGraph> dense_copy(
new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
- rewrite_metadata.fetch_types));
+ rewrite_metadata.fetch_types, collective_graph_key));
CopyGraph(*optimized_graph, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.
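A standalone sketch of the collective_graph_key derivation added above; the instance-key values are hypothetical, while the seed constant and the Hash64Combine usage are taken from the diff:

  std::set<int32> instance_key_set = {3, 7, 11};  // hypothetical collective instance keys
  uint64 collective_graph_key = 0x8774aa605c729c72ULL;
  for (int32 instance_key : instance_key_set) {   // std::set iterates in ascending order
    collective_graph_key = Hash64Combine(instance_key, collective_graph_key);
  }
  // collective_graph_key is then stored on the ClientGraph.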
diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h
index d44a24c87b..9cabe478a6 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.h
+++ b/tensorflow/core/common_runtime/graph_execution_state.h
@@ -50,17 +50,20 @@ struct GraphExecutionStateOptions {
// BuildGraphOptions.
struct ClientGraph {
explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
- DataTypeVector feed_types, DataTypeVector fetch_types)
+ DataTypeVector feed_types, DataTypeVector fetch_types,
+ int64 collective_graph_key)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
- fetch_types(std::move(fetch_types)) {}
+ fetch_types(std::move(fetch_types)),
+ collective_graph_key(collective_graph_key) {}
// Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph;
DataTypeVector feed_types;
DataTypeVector fetch_types;
+ int64 collective_graph_key;
};
// GraphExecutionState is responsible for generating an
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 0a1797fa19..f9aef3af70 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous {
}
mutex_lock l(mu_);
- string edge_name = std::string(parsed.edge_name);
+ string edge_name(parsed.edge_name);
if (table_.count(edge_name) > 0) {
return errors::Internal("Send of an already sent tensor");
}
@@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous {
Tensor tensor;
Status status = Status::OK();
{
- string key = std::string(parsed.edge_name);
+ string key(parsed.edge_name);
mutex_lock l(mu_);
if (table_.count(key) <= 0) {
status = errors::Internal("Did not find key ", key);
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 6b76e7e0e7..538a70668a 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -23,10 +23,11 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/numa.h"
#ifndef INTEL_MKL_DNN_ONLY
#include "i_malloc.h"
@@ -38,19 +39,113 @@ typedef unsigned int uint;
namespace tensorflow {
-class MklSubAllocator : public SubAllocator {
+class MklSubAllocator : public BasicCPUAllocator {
public:
+ MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {}
~MklSubAllocator() override {}
+};
+
+// CPU allocator that handles small-size allocations by calling the
+// suballocator directly. It is essentially a wrapper around a suballocator
+// (which calls malloc and free directly), with added bookkeeping.
+class MklSmallSizeAllocator : public Allocator {
+ public:
+ MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
+ const string& name)
+ : sub_allocator_(sub_allocator), name_(name) {
+ stats_.bytes_limit = total_memory;
+ }
+ ~MklSmallSizeAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator);
+
+ inline string Name() override { return name_; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* ptr = sub_allocator_->Alloc(alignment, num_bytes);
+ if (ptr != nullptr) {
+ std::pair<void*, size_t> map_val(ptr, num_bytes);
+ mutex_lock l(mutex_);
+ // Check that insertion in the hash map was successful.
+ CHECK(map_.insert(map_val).second);
+ // Increment statistics for small-size allocations.
+ IncrementStats(num_bytes);
+ }
+ return ptr;
+ }
+
+ void DeallocateRaw(void* ptr) override {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+
+ mutex_lock l(mutex_);
+ auto map_iter = map_.find(ptr);
+ if (map_iter != map_.end()) {
+ // Call free visitors.
+ size_t dealloc_bytes = map_iter->second;
+ sub_allocator_->Free(ptr, dealloc_bytes);
+ DecrementStats(dealloc_bytes);
+ map_.erase(map_iter);
+ } else {
+ LOG(ERROR) << "tried to deallocate invalid pointer";
+ return;
+ }
+ }
+
+ inline bool IsSmallSizeAllocation(const void* ptr) const {
+ mutex_lock l(mutex_);
+ return map_.find(ptr) != map_.end();
+ }
+
+ void GetStats(AllocatorStats* stats) override {
+ mutex_lock l(mutex_);
+ *stats = stats_;
+ }
- void* Alloc(size_t alignment, size_t num_bytes) override {
- return port::AlignedMalloc(num_bytes, alignment);
+ void ClearStats() override {
+ mutex_lock l(mutex_);
+ stats_.Clear();
}
- void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
+
+ private:
+ // Increment statistics for the allocator handling small allocations.
+ inline void IncrementStats(size_t alloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ ++stats_.num_allocs;
+ stats_.bytes_in_use += alloc_size;
+ stats_.max_bytes_in_use =
+ std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
+ stats_.max_alloc_size =
+ std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size));
+ }
+
+ // Decrement statistics for the allocator handling small allocations.
+ inline void DecrementStats(size_t dealloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ stats_.bytes_in_use -= dealloc_size;
+ }
+
+ SubAllocator* sub_allocator_; // Not owned by this class.
+
+ // Mutex for protecting updates to map of allocations.
+ mutable mutex mutex_;
+
+ // Allocator name
+ string name_;
+
+ // Hash map to keep track of "small" allocations.
+ // We do not use the BFC allocator for small allocations.
+ std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_);
+
+ // Allocator stats for small allocs
+ AllocatorStats stats_ GUARDED_BY(mutex_);
};
/// CPU allocator for MKL that wraps BFC allocator and intercepts
/// and redirects memory allocation calls from MKL.
-class MklCPUAllocator : public VisitableAllocator {
+class MklCPUAllocator : public Allocator {
public:
// Constructor and other standard functions
@@ -62,7 +157,10 @@ class MklCPUAllocator : public VisitableAllocator {
MklCPUAllocator() { TF_CHECK_OK(Initialize()); }
- ~MklCPUAllocator() override { delete allocator_; }
+ ~MklCPUAllocator() override {
+ delete small_size_allocator_;
+ delete large_size_allocator_;
+ }
Status Initialize() {
VLOG(2) << "MklCPUAllocator: In MklCPUAllocator";
@@ -96,8 +194,15 @@ class MklCPUAllocator : public VisitableAllocator {
}
VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
- allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
- kAllowGrowth, kName);
+
+ sub_allocator_ = new MklSubAllocator();
+
+ // SubAllocator is owned by BFCAllocator, so we do not need to deallocate
+ // it in MklSmallSizeAllocator.
+ small_size_allocator_ =
+ new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName);
+ large_size_allocator_ =
+ new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName);
#ifndef INTEL_MKL_DNN_ONLY
// For redirecting all allocations from MKL to this allocator
// From: http://software.intel.com/en-us/node/528565
@@ -112,23 +217,45 @@ class MklCPUAllocator : public VisitableAllocator {
inline string Name() override { return kName; }
inline void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return allocator_->AllocateRaw(alignment, num_bytes);
+ // If the allocation size is below the threshold, use the small-size
+ // allocator; otherwise use the large-size (BFC) allocator. We found that
+ // the BFC allocator does not deliver good performance for small allocations
+ // when inter_op_parallelism_threads is high.
+ return (num_bytes < kSmallAllocationsThreshold)
+ ? small_size_allocator_->AllocateRaw(alignment, num_bytes)
+ : large_size_allocator_->AllocateRaw(alignment, num_bytes);
}
inline void DeallocateRaw(void* ptr) override {
- allocator_->DeallocateRaw(ptr);
+ // Check whether ptr came from a "small" allocation. If so, free it
+ // directly; otherwise let the BFC allocator handle the free.
+ if (small_size_allocator_->IsSmallSizeAllocation(ptr)) {
+ small_size_allocator_->DeallocateRaw(ptr);
+ } else {
+ large_size_allocator_->DeallocateRaw(ptr);
+ }
}
- void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); }
-
- void ClearStats() override { allocator_->ClearStats(); }
-
- void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
+ void GetStats(AllocatorStats* stats) override {
+ AllocatorStats l_stats, s_stats;
+ small_size_allocator_->GetStats(&s_stats);
+ large_size_allocator_->GetStats(&l_stats);
+
+ // Combine statistics from small-size and large-size allocator.
+ stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs;
+ stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use;
+ stats->max_bytes_in_use =
+ l_stats.max_bytes_in_use + s_stats.max_bytes_in_use;
+
+ // Since small-size allocations go to MklSmallSizeAllocator,
+ // max_alloc_size from large_size_allocator would be the maximum
+ // size allocated by MklCPUAllocator.
+ stats->max_alloc_size = l_stats.max_alloc_size;
}
- void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
+ void ClearStats() override {
+ small_size_allocator_->ClearStats();
+ large_size_allocator_->ClearStats();
}
private:
@@ -148,26 +275,33 @@ class MklCPUAllocator : public VisitableAllocator {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
- return nullptr; // return a value and make static code analyzers happy
+ return nullptr; // return a value and make static code analyzers happy
}
static inline void* ReallocHook(void* ptr, size_t size) {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
- return nullptr; // return a value and make static code analyzers happy
+ return nullptr; // return a value and make static code analyzers happy
}
- /// Do we allow growth in BFC Allocator
+ // Do we allow growth in BFC Allocator
static const bool kAllowGrowth = true;
- /// Name
+ // Name
static constexpr const char* kName = "mklcpu";
- /// The alignment that we need for the allocations
+ // The alignment that we need for the allocations
static constexpr const size_t kAlignment = 64;
- VisitableAllocator* allocator_; // owned by this class
+ Allocator* large_size_allocator_; // owned by this class
+ MklSmallSizeAllocator* small_size_allocator_; // owned by this class.
+
+ SubAllocator* sub_allocator_; // not owned by this class
+
+ // Size in bytes that defines the upper bound for "small" allocations.
+ // Any allocation below this threshold is a "small" allocation.
+ static constexpr const size_t kSmallAllocationsThreshold = 4096;
// Prevent copying and assignment
TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator);
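A condensed sketch of the dispatch policy introduced above: allocations below kSmallAllocationsThreshold (4096 bytes) go to MklSmallSizeAllocator, everything else to the BFC allocator, and frees are routed by consulting the small allocator's pointer map. Names follow the diff; the standalone functions are illustrative only:

  void* MklAllocate(size_t alignment, size_t num_bytes,
                    MklSmallSizeAllocator* small_alloc, Allocator* bfc) {
    constexpr size_t kSmallAllocationsThreshold = 4096;  // from this diff
    return (num_bytes < kSmallAllocationsThreshold)
               ? small_alloc->AllocateRaw(alignment, num_bytes)
               : bfc->AllocateRaw(alignment, num_bytes);
  }

  void MklDeallocate(void* ptr, MklSmallSizeAllocator* small_alloc,
                     Allocator* bfc) {
    if (small_alloc->IsSmallSizeAllocation(ptr)) {
      small_alloc->DeallocateRaw(ptr);
    } else {
      bfc->DeallocateRaw(ptr);
    }
  }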
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index d581f45a90..3b59995433 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/status_util.h"
namespace tensorflow {
@@ -255,9 +254,11 @@ class ColocationGraph {
old_root_member.device_name,
allow_soft_placement_);
if (!s.ok()) {
- return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
- "' and '", y.name(), ": ",
- s.error_message());
+ return errors::InvalidArgument(
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()), ": ",
+ s.error_message());
}
// Ensure that the common root has at least one supported device
@@ -268,8 +269,10 @@ class ColocationGraph {
old_root_member.supported_device_types);
if (new_root_member.supported_device_types.empty()) {
return errors::InvalidArgument(
- "Cannot colocate nodes '", x.name(), "' and '", y.name(),
- "' because no device type supports both of those nodes and the "
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()),
+ " because no device type supports both of those nodes and the "
"other nodes colocated with them.",
DebugInfo(x_root), DebugInfo(y_root));
}
@@ -377,8 +380,9 @@ class ColocationGraph {
// merged set device is different, so print both.
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
- node->requested_device(),
- "' because the node was colocated with a group of nodes that "
+ node->requested_device(), "' because the node ",
+ errors::FormatColocationNodeForError(node->name()),
+ " was colocated with a group of nodes that ",
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
members_[node_root].device_name),
@@ -810,10 +814,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
// Returns the first device in sorted devices list so we will always
@@ -857,10 +861,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
int assigned_device = -1;
@@ -926,22 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const {
}
}
-bool Placer::ClientHandlesErrorFormatting() const {
- return options_ != nullptr &&
- options_->config.experimental().client_handles_error_formatting();
-}
-
-// Returns the node name in single quotes. If the client handles formatted
-// errors, appends a formatting tag which the client will reformat into, for
-// example, " (defined at filename:123)".
-string Placer::RichNodeName(const Node* node) const {
- string quoted_name = strings::StrCat("'", node->name(), "'");
- if (ClientHandlesErrorFormatting()) {
- string file_and_line = error_format_tag(*node, "${defined_at}");
- return strings::StrCat(quoted_name, file_and_line);
- } else {
- return quoted_name;
- }
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h
index cefcdd25db..f97ffe7372 100644
--- a/tensorflow/core/common_runtime/placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -87,8 +87,6 @@ class Placer {
// placement if the SessionOptions entry in 'options_' requests it.
void AssignAndLog(int assigned_device, Node* node) const;
void LogDeviceAssignment(const Node* node) const;
- bool ClientHandlesErrorFormatting() const;
- string RichNodeName(const Node* node) const;
Graph* const graph_; // Not owned.
const DeviceSet* const devices_; // Not owned.
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 87f2f2ceb9..9b8a95e3b6 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) {
}
Status s = Place(&g);
- EXPECT_TRUE(
- str_util::StrContains(s.error_message(),
- "Cannot colocate nodes 'foo' and 'in' because no "
- "device type supports both of those nodes and the "
- "other nodes colocated with them"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot colocate nodes {{colocation_node foo}} and "
+ "{{colocation_node in}} because no device type supports both of those "
+ "nodes and the other nodes colocated with them"));
}
TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
@@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) {
Status s = Place(&g);
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
- "Cannot colocate nodes 'var3' and 'assign3' because no "
- "device type supports both of those nodes and the other "
- "nodes colocated with them."));
+ "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node "
+ "assign3}} because no device type supports both of those nodes and the "
+ "other nodes colocated with them."));
}
TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
@@ -1154,36 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
}
SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- true);
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
LOG(WARNING) << s.error_message();
EXPECT_TRUE(str_util::StrContains(s.error_message(),
- "Cannot assign a device for operation 'in'"
- "^^node:in:${defined_at}^^"));
-}
-
-// Test that the "Cannot assign a device" error message does not contain a
-// format tag when not it shouldn't
-TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- ops::SourceOp("TestDevice",
- b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- false);
- Status s = Place(&g, &options);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot assign a device for operation 'in'"));
- EXPECT_FALSE(str_util::StrContains(
- s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)"));
+ "Cannot assign a device for operation in"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}"));
}
// Test that placement fails when a node requests an explicit device that is not
@@ -1289,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot colocate nodes 'var' and 'assign'"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Cannot colocate nodes {{colocation_node "
+ "var}} and {{colocation_node assign}}"));
}
// Test that a generator node follows its consumers (where there are several
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 10a24ed14c..66dc8f3322 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -39,8 +40,7 @@ PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
auto_resize_(auto_resize),
pool_size_limit_(pool_size_limit),
allocator_(allocator),
- size_rounder_(size_rounder),
- allocation_begun_(false) {
+ size_rounder_(size_rounder) {
if (auto_resize) {
CHECK_LT(size_t{0}, pool_size_limit)
<< "size limit must be > 0 if auto_resize is true.";
@@ -92,7 +92,6 @@ ChunkPrefix* FindPrefix(void* user_ptr) {
} // namespace
void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
- if (!allocation_begun_) allocation_begun_ = true;
if (num_bytes == 0) return nullptr;
// If alignment is larger than kPoolAlignment, increase num_bytes so that we
@@ -128,9 +127,6 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return PrepareChunk(r, alignment, num_bytes);
} else {
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
- for (const auto& v : alloc_visitors_) {
- v(ptr, num_bytes);
- }
return PrepareChunk(ptr, alignment, num_bytes);
}
}
@@ -140,9 +136,6 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
ChunkPrefix* cp = FindPrefix(ptr);
CHECK_LE((void*)cp, (void*)ptr);
if (!has_size_limit_ && !auto_resize_) {
- for (const auto& v : free_visitors_) {
- v(cp, cp->num_bytes);
- }
allocator_->Free(cp, cp->num_bytes);
} else {
mutex_lock lock(mutex_);
@@ -163,9 +156,6 @@ void PoolAllocator::Clear() {
mutex_lock lock(mutex_);
for (auto iter : pool_) {
PtrRecord* pr = iter.second;
- for (const auto& v : free_visitors_) {
- v(pr->ptr, pr->num_bytes);
- }
allocator_->Free(pr->ptr, pr->num_bytes);
delete pr;
}
@@ -220,9 +210,6 @@ void PoolAllocator::EvictOne() {
DCHECK(iter != pool_.end());
}
pool_.erase(iter);
- for (const auto& v : free_visitors_) {
- v(prec->ptr, prec->num_bytes);
- }
allocator_->Free(prec->ptr, prec->num_bytes);
delete prec;
++evicted_count_;
@@ -268,28 +255,19 @@ void PoolAllocator::EvictOne() {
}
}
-void PoolAllocator::AddAllocVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddAllocVisitor may not be called after pool allocation "
- << "has begun.";
- alloc_visitors_.push_back(visitor);
-}
-
-void PoolAllocator::AddFreeVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddFreeVisitor may not be called after pool allocation "
- << "has begun.";
- free_visitors_.push_back(visitor);
-}
-
void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
- return port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ VisitAlloc(ptr, numa_node_, num_bytes);
+ }
+ return ptr;
}
void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) {
- port::AlignedFree(ptr);
+ if (num_bytes > 0) {
+ VisitFree(ptr, numa_node_, num_bytes);
+ port::AlignedFree(ptr);
+ }
}
-
} // namespace tensorflow
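With the visitor plumbing moved from PoolAllocator into SubAllocator, a pool built on BasicCPUAllocator now attaches visitors at construction time. A sketch with placeholder lambda bodies (the pool name is illustrative):

  std::vector<SubAllocator::Visitor> alloc_visitors = {
      [](void* ptr, int numa_node, int64 num_bytes) {
        VLOG(1) << "alloc " << num_bytes << " bytes on NUMA node " << numa_node;
      }};
  std::vector<SubAllocator::Visitor> free_visitors = {
      [](void* ptr, int numa_node, int64 num_bytes) {
        VLOG(1) << "free " << num_bytes << " bytes on NUMA node " << numa_node;
      }};
  PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/,
                     new BasicCPUAllocator(port::kNUMANoAffinity, alloc_visitors,
                                           free_visitors),
                     new NoopRounder, "cpu_pool");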
diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 607734445b..5b4623ba10 100644
--- a/tensorflow/core/common_runtime/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
-// Simple LRU pool allocators for various flavors of CPU RAM that
-// implement the VisitableAllocator interface.
+// Simple LRU pool allocators for various flavors of CPU RAM.
#include <atomic>
#include <map>
#include <memory>
#include <vector>
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,7 +40,7 @@ class RoundUpInterface {
// Size-limited pool of memory buffers obtained from a SubAllocator
// instance. Pool eviction policy is LRU.
-class PoolAllocator : public VisitableAllocator {
+class PoolAllocator : public Allocator {
public:
// "pool_size_limit" is the maximum number of returned, re-usable
// memory buffers to keep in the pool. If pool_size_limit == 0, the
@@ -64,14 +63,6 @@ class PoolAllocator : public VisitableAllocator {
void DeallocateRaw(void* ptr) override;
- // REQUIRES: The following functions may only be called prior
- // to the first Allocate*() call. Once allocation has begun, it is
- // illegal to register another visitor.
-
- void AddAllocVisitor(Visitor visitor) override;
-
- void AddFreeVisitor(Visitor visitor) override;
-
// Allocate an unused memory region of size "num_bytes". Fetch from
// the pool if available, otherwise call allocator_.
void* Get(size_t num_bytes);
@@ -141,12 +132,6 @@ class PoolAllocator : public VisitableAllocator {
int64 put_count_ GUARDED_BY(mutex_) = 0;
int64 allocated_count_ GUARDED_BY(mutex_) = 0;
int64 evicted_count_ GUARDED_BY(mutex_) = 0;
- // Write access to these is guarded by mutex_, but not read
- // access. They may only be modified prior to the first
- // allocation. Later attempts to modify will fail.
- std::vector<Visitor> alloc_visitors_;
- std::vector<Visitor> free_visitors_;
- std::atomic<bool> allocation_begun_;
};
// Do-nothing rounder. Passes through sizes unchanged.
@@ -166,7 +151,9 @@ class Pow2Rounder : public RoundUpInterface {
class BasicCPUAllocator : public SubAllocator {
public:
// Argument numa_node is currently ignored.
- explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {}
+ BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {}
~BasicCPUAllocator() override {}
@@ -176,6 +163,8 @@ class BasicCPUAllocator : public SubAllocator {
private:
int numa_node_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator);
};
} // namespace tensorflow
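
With visitor support removed from PoolAllocator, visitors now travel with the SubAllocator itself. A minimal sketch of wiring a BasicCPUAllocator (per the constructor added above) into a PoolAllocator so that raw Alloc/Free calls reach the visitors; the helper name MakeVisitedCpuPool is hypothetical:

    // Sketch only: builds a PoolAllocator whose BasicCPUAllocator invokes the
    // given visitors (signature: void(void* ptr, int index, size_t bytes)).
    PoolAllocator* MakeVisitedCpuPool(int numa_node,
                                      SubAllocator::Visitor alloc_visitor,
                                      SubAllocator::Visitor free_visitor) {
      SubAllocator* sub = new BasicCPUAllocator(
          numa_node, {std::move(alloc_visitor)}, {std::move(free_visitor)});
      return new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
                               sub, new NoopRounder, "cpu_pool");
    }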
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 447338e7bd..bcaa37fc8a 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,20 +71,28 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
CHECK_GE(numa_node, 0);
if (!numa_enabled_) numa_node = 0;
mutex_lock lock(mu_);
while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+    // If visitors have been defined, we need an Allocator built from a
+    // SubAllocator. Prefer BFCAllocator, but fall back to PoolAllocator
+    // depending on the env var setting.
+ const bool alloc_visitors_defined =
+ (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty());
bool use_bfc_allocator = false;
- // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
- // efficient.
- Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
- &use_bfc_allocator);
+ Status status = ReadBoolFromEnvVar(
+ "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator);
if (!status.ok()) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
- VisitableAllocator* allocator;
+ Allocator* allocator = nullptr;
+ SubAllocator* sub_allocator =
+ (alloc_visitors_defined || use_bfc_allocator)
+ ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1,
+ cpu_alloc_visitors_, cpu_free_visitors_)
+ : nullptr;
if (use_bfc_allocator) {
// TODO(reedwm): evaluate whether 64GB by default is the best choice.
int64 cpu_mem_limit_in_mb = -1;
@@ -95,34 +103,63 @@ VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
- allocator = new BFCAllocator(
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit,
- true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/);
+ DCHECK(sub_allocator);
+ allocator =
+ new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/,
+ "bfc_cpu_allocator_for_gpu" /*name*/);
VLOG(2) << "Using BFCAllocator with memory limit of "
<< cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
- } else {
- allocator = new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1),
- new NoopRounder, "cpu_pool");
+ } else if (alloc_visitors_defined) {
+ DCHECK(sub_allocator);
+ allocator =
+ new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
+ sub_allocator, new NoopRounder, "cpu_pool");
VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator "
<< "numa_enabled_=" << numa_enabled_
<< " numa_node=" << numa_node;
+ } else {
+ DCHECK(!sub_allocator);
+ allocator = cpu_allocator();
}
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
cpu_allocators_.push_back(allocator);
+ if (!sub_allocator) {
+ DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty());
+ }
}
return cpu_allocators_[numa_node];
}
+void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) {
+ VLOG(1) << "AddCPUAllocVisitor";
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUAllocVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_alloc_visitors_.push_back(std::move(visitor));
+}
+
+void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) {
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUFreeVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_free_visitors_.push_back(std::move(visitor));
+}
+
void ProcessState::TestOnlyReset() {
mutex_lock lock(mu_);
+ // Don't delete this value because it's static.
+ Allocator* default_cpu_allocator = cpu_allocator();
mem_desc_map_.clear();
- gtl::STLDeleteElements(&cpu_allocators_);
+ for (Allocator* a : cpu_allocators_) {
+ if (a != default_cpu_allocator) delete a;
+ }
+ cpu_allocators_.clear();
gtl::STLDeleteElements(&cpu_al_);
}
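
Registration of CPU visitors now happens on ProcessState before the first GetCPUAllocator call, as the new CHECKs enforce. A minimal usage sketch, assuming the usual ProcessState::singleton() accessor (the visitor body is illustrative only):

    // Must run before anything triggers ProcessState::GetCPUAllocator().
    ProcessState* ps = ProcessState::singleton();
    ps->AddCPUAllocVisitor([](void* ptr, int numa_node, size_t num_bytes) {
      VLOG(2) << "CPU alloc: " << num_bytes << " bytes on numa node "
              << numa_node << " at " << ptr;
    });
    Allocator* cpu_alloc = ps->GetCPUAllocator(0 /*numa_node*/);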
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 2892677333..cac312d849 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -30,7 +30,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state, e.g. allocation of
@@ -65,7 +64,15 @@ class ProcessState {
// Returns the one CPUAllocator used for the given numa_node.
// TEMPORARY: ignores numa_node.
- VisitableAllocator* GetCPUAllocator(int numa_node);
+ Allocator* GetCPUAllocator(int numa_node);
+
+  // Registers an alloc visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUAllocVisitor(SubAllocator::Visitor v);
+
+  // Registers a free visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUFreeVisitor(SubAllocator::Visitor v);
typedef std::unordered_map<const void*, MemDesc> MDMap;
@@ -87,7 +94,9 @@ class ProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_);
virtual ~ProcessState();
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 103eee03b3..9d59264899 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -72,9 +72,10 @@ class RenamedDevice : public Device {
return underlying_->MakeGpuDevice();
}
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override {
- underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override {
+ return underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
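
Because ReinitializeGpuDevice now returns a Status, call sites can propagate initialization failures instead of ignoring them. A hedged call-site sketch (the wrapper function itself is hypothetical):

    Status SetUpPerOpDevice(Device* device, OpKernelContext* ctx,
                            PerOpGpuDevice* gpu_device, DeviceContext* dc,
                            Allocator* allocator) {
      // Forward any error from the underlying device.
      TF_RETURN_IF_ERROR(
          device->ReinitializeGpuDevice(ctx, gpu_device, dc, allocator));
      return Status::OK();
    }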
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 1e3fed0d6f..43ca3f1e3e 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/reffed_status_callback.h"
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 65ff356e73..5b1915755d 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names,
// Save only the tensors in output_names in the session.
for (const string& name : output_names) {
TensorId id(ParseTensorName(name));
- const string& op_name = std::string(id.first);
+ const string op_name(id.first);
auto it = tensors_.find(op_name);
if (it != tensors_.end()) {
// Save the tensor to the session state.
diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
index 04d5af9087..22650b0d83 100644
--- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h
+++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 9c2510e6a9..a70ab93d4a 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
@@ -40,46 +41,24 @@ struct AllocStats {
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
- : NodeExecStatsWrapper(new NodeExecStats) {
- stats_->set_node_name(node_name);
-}
-NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
- : stats_(stats) {}
-
-void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
- DCHECK(v);
- NodeOutput* no = stats_->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
-}
-
-void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats_->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ const Node* node, StepStatsCollector* step_stats_collector)
+ : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
+ step_stats_collector) {
+ stats_->set_node_name(node->name());
}
-void NodeExecStatsWrapper::SetReferencedTensors(
- const TensorReferenceVector& tensors) {
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description = stats_->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
- }
-}
-
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
- bool is_transfer_node = false;
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector)
+ : stats_(std::move(stats)),
+ node_(node),
+ step_stats_collector_(step_stats_collector) {}
+
+void NodeExecStatsWrapper::Done(const string& device) {
+ // TODO(tucker): merge with the DetailText function in session.cc in a common
+ // location.
+ DCHECK(node_);
string memory;
for (auto& all : stats_->memory()) {
int64 tot = all.total_bytes();
@@ -96,31 +75,96 @@ bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
}
}
}
- const AttrSlice attrs = node->attrs();
+ const AttrSlice attrs = node_->attrs();
string text;
- if (IsSend(node)) {
+ if (IsSend(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string recv_device;
TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
+ } else if (IsRecv(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string send_device;
TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", send_device);
- is_transfer_node = true;
} else {
text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
+ strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
+ str_util::Join(node_->requested_inputs(), ", "), ")");
}
stats_->set_timeline_label(text);
- return is_transfer_node;
+ step_stats_collector_->Save(device, this);
+}
+
+void NodeExecStatsWrapper::RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+}
+
+void NodeExecStatsWrapper::RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
+ DCHECK(tensor);
+ NodeOutput* node_output = stats_->add_output();
+ node_output->set_slot(slot);
+ tensor->FillDescription(node_output->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // be careful not to increment the reference count on any tensor
+ // while recording the information
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
}
void NodeExecStatsWrapper::AddAllocation(
@@ -150,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() {
allocations_.clear();
}
-StepStatsCollector::StepStatsCollector(StepStats* ss)
- : finalized_(false), step_stats_(ss) {}
+StepStatsCollector::StepStatsCollector(StepStats* step_stats)
+ : finalized_(false), step_stats_(step_stats) {}
static int ExtractGpuWithStreamAll(string device_name) {
// Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -176,7 +220,7 @@ static int ExtractGpuWithStreamAll(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -205,7 +249,7 @@ static int ExtractGpuWithoutStream(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -252,7 +296,7 @@ void StepStatsCollector::BuildCostModel(
for (auto& itr : per_device_stats) {
const StringPiece device_name = itr.first;
- const int gpu_id = ExtractGpuWithoutStream(std::string(device_name));
+ const int gpu_id = ExtractGpuWithoutStream(string(device_name));
if (gpu_id >= 0) {
// Reference the gpu hardware stats in addition to the regular stats
// for this gpu device if they're available.
@@ -338,28 +382,40 @@ void StepStatsCollector::BuildCostModel(
}
}
-void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
- Save(device, new NodeExecStatsWrapper(nt));
+void StepStatsCollector::Save(const string& device,
+ NodeExecStats* node_stats_pb) {
+ Save(device,
+ new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
+ nullptr, this));
}
void StepStatsCollector::Save(const string& device,
- NodeExecStatsWrapper* stats) {
- if (!stats) return;
- VLOG(1) << "Save dev " << device << " nt " << stats->stats();
+ NodeExecStatsWrapper* node_stats) {
+ if (!node_stats) return;
+ VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
{
mutex_lock l(mu_);
if (finalized_) {
LOG(WARNING) << "stats saved after finalize will not be collected.";
}
- if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+ if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
- delete stats;
+ delete node_stats;
return;
}
- auto& dss = dev_stats_[device];
- dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
- collectedNodes++;
+ auto& device_stats = dev_stats_[device];
+ device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
+ collected_nodes_++;
+ }
+}
+
+NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
+ const Node* node) {
+ // Only collect statistics for non-transfer nodes.
+ if (IsSend(node) || IsRecv(node)) {
+ return nullptr;
}
+ return new NodeExecStatsWrapper(node, this);
}
string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
@@ -446,12 +502,12 @@ void StepStatsCollector::Finalize() {
FinalizeInternal();
}
-void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
+void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
mutex_lock l(mu_);
CHECK(step_stats_);
FinalizeInternal();
- ss->Swap(step_stats_);
- collectedNodes = 0;
+ step_stats->Swap(step_stats_);
+ collected_nodes_ = 0;
}
void StepStatsCollector::FinalizeInternal() {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 7206fbf427..4365b11b19 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -36,81 +36,78 @@ class Node;
class NodeExecStats;
class OpKernelContext;
class StepStats;
+class StepStatsCollector;
class Tensor;
class TrackingAllocator;
-// Wraps NodeExecStats and adds allocation to it.
-class NodeExecStatsWrapper {
+// Statistics collection interface for individual node execution.
+//
+// See `NodeExecStatsWrapper` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class NodeExecStatsInterface {
public:
- NodeExecStatsWrapper(const string& node_name);
- // Owns 'stats'.
- NodeExecStatsWrapper(NodeExecStats* stats);
+ virtual ~NodeExecStatsInterface() {}
- // Destructor calls Finalize() to release the TrackingAllocators.
- ~NodeExecStatsWrapper() { Finalize(); }
-
- // Records the absolute time in nanoseconds at which this node became
- // runnable (i.e. was scheduled for execution).
- void SetScheduled(int64 nanos) {
- stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats_->set_scheduled_nanos(nanos);
- }
+ // Called when the statistics collection for the node has finished. Once this
+ // method is called, the caller should not make assumptions about the validity
+ // of this object.
+ virtual void Done(const string& device) = 0;
// Called immediately after this node starts being processed by the executor.
- void RecordExecutorStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats_->set_all_start_nanos(now_nanos);
- }
+ virtual void RecordExecutorStarted() = 0;
// Called immediately before this node's `Compute()` or `ComputeAsync()`
// method is called.
- void RecordComputeStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeStarted() = 0;
// Called immediately after this node's `Compute()` method returned (or, for
// asynchronous operations, the callback passed to its `ComputeAsync()` method
// was called).
- void RecordComputeEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeEnded() = 0;
// Called immediately after this executor finishes processing this node.
- void RecordExecutorEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
-
- // Records information about the tensor produced by this node at the given
- // output slot.
- void SetOutput(int slot, const Tensor* v);
+ virtual void RecordExecutorEnded() = 0;
// Records information about the memory allocated during the execution of this
// node.
- void SetMemory(OpKernelContext* ctx);
+ virtual void SetMemory(OpKernelContext* ctx) = 0;
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ virtual void SetOutput(int slot, const Tensor* tensor) = 0;
// Records information about the tensors that were accessed during the
// execution of this node.
- void SetReferencedTensors(const TensorReferenceVector& tensors);
+ virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
- // Sets the timeline_label field of the wrapped NodeExecStats, using data
- // from *node. Returns true iff the node is a transfer node.
- bool SetTimelineLabel(const Node* node);
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ virtual void SetScheduled(int64 nanos) = 0;
+};
+
+// Wraps NodeExecStats and adds allocation to it.
+class NodeExecStatsWrapper : public NodeExecStatsInterface {
+ public:
+ // Does not take ownership of `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Takes ownership of 'stats' but not `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Destructor calls Finalize() to release the TrackingAllocators.
+ ~NodeExecStatsWrapper() { Finalize(); }
+
+ void Done(const string& device) override;
+ void RecordExecutorStarted() override;
+ void RecordComputeStarted() override;
+ void RecordComputeEnded() override;
+ void RecordExecutorEnded() override;
+ void SetMemory(OpKernelContext* ctx) override;
+ void SetOutput(int slot, const Tensor* tensor) override;
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+ void SetScheduled(int64 nanos) override;
private:
friend class StepStatsCollector;
@@ -128,9 +125,11 @@ class NodeExecStatsWrapper {
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
+ const Node* const node_; // Not owned.
+ StepStatsCollector* const step_stats_collector_; // Not owned.
};
-// Statistics collection interface for individual node execution.
+// Statistics collection interface for step execution.
//
// See `StepStatsCollector` for a concrete implementation of this interface
// that interfaces with the `Session` layer.
@@ -138,8 +137,9 @@ class StepStatsCollectorInterface {
public:
virtual ~StepStatsCollectorInterface() {}
- // Saves `stats` to the collector.
- virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+ // Creates an instance of `NodeExecStatsInterface` that should be used for
+ // collecting statistics about individual node execution.
+ virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
// Generates a string reporting the currently used memory based
// on ResourceExhausted OOM `err` message.
@@ -154,8 +154,8 @@ class StepStatsCollectorInterface {
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector : public StepStatsCollectorInterface {
public:
- // Does not take ownership of `ss`.
- explicit StepStatsCollector(StepStats* ss);
+ // Does not take ownership of `step_stats`.
+ explicit StepStatsCollector(StepStats* step_stats);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
@@ -164,11 +164,12 @@ class StepStatsCollector : public StepStatsCollectorInterface {
CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map);
- // Save saves nt to the DeviceStats object associated with device.
+ // Saves node statistics to the DeviceStats object associated with device.
// Should be called before Finalize.
- void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats) override;
+ void Save(const string& device, NodeExecStats* node_stats_pb);
+ void Save(const string& device, NodeExecStatsWrapper* node_stats);
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
@@ -176,20 +177,22 @@ class StepStatsCollector : public StepStatsCollectorInterface {
// User shouldn't call Save() methods after Finalize.
void Finalize();
// swaps the content of StepStats* from constructor with 'ss'.
- void FinalizeAndSwap(StepStats* ss);
+ void FinalizeAndSwap(StepStats* step_stats);
private:
+  // TODO(suharshs): Make this configurable if it's not possible to find a
+  // value that works for all cases.
+ static const uint64 kMaxCollectedNodes = 1 << 20;
+
+ typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+
void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
- typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
- // TODO(suharshs): Make this configurable if its not possible to find a value
- // that works for all cases.
- const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
bool finalized_ GUARDED_BY(mu_);
- std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
+ std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
- uint64 collectedNodes GUARDED_BY(mu_) = 0;
+ uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow
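
With this split, the executor no longer calls Save() directly: it asks the collector for a per-node stats object, records events on it, and finishes with Done(), which saves the stats back into the collector. A simplified sketch of the intended call sequence (the wrapper function is illustrative, not taken from this change):

    void RunNodeWithStats(StepStatsCollectorInterface* stats_collector,
                          const Node* node, const string& device,
                          OpKernel* op_kernel, OpKernelContext* ctx) {
      NodeExecStatsInterface* stats = stats_collector->CreateNodeExecStats(node);
      if (stats == nullptr) {
        // Transfer (Send/Recv) nodes are not tracked by StepStatsCollector.
        op_kernel->Compute(ctx);
        return;
      }
      stats->RecordExecutorStarted();
      stats->RecordComputeStarted();
      op_kernel->Compute(ctx);
      stats->RecordComputeEnded();
      stats->SetMemory(ctx);
      stats->RecordExecutorEnded();
      stats->Done(device);  // Hands the stats back to the collector.
    }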
diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h
index 39215efa35..e1b163074f 100644
--- a/tensorflow/core/common_runtime/tracing_device.h
+++ b/tensorflow/core/common_runtime/tracing_device.h
@@ -35,8 +35,11 @@ class TracingDevice : public Device {
: Device(env, attributes) {}
void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
+ const tracing::TraceCollector* trace_collector =
+ tracing::GetTraceCollector();
if (TF_PREDICT_FALSE(
- tracing::GetTraceCollector() ||
+ (trace_collector &&
+ trace_collector->IsEnabled(op_kernel->IsExpensive())) ||
tracing::GetEventCollector(tracing::EventCategory::kCompute))) {
const string& op_name = op_kernel->name();
tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
deleted file mode 100644
index ae0563a96a..0000000000
--- a/tensorflow/core/common_runtime/visitable_allocator.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-
-#include <functional>
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/tracking_allocator.h"
-
-namespace tensorflow {
-
-// Subclass VisitableAllocator instead of Allocator when a memory
-// allocator needs to enable some kind of registration/deregistration
-// of memory areas.
-class VisitableAllocator : public Allocator {
- public:
- // Visitor gets called with a pointer to a memory area and its
- // size in bytes.
- typedef std::function<void(void*, size_t)> Visitor;
-
- // Register a visitor guaranteed to be called exactly once on each
- // chunk of memory newly allocated from the underlying device.
- // Typically, chunks will be reused and possibly sub-divided by a
- // pool manager, so the calls will happen only once per process
- // execution, not once per tensor (re)allocation.
- virtual void AddAllocVisitor(Visitor visitor) = 0;
-
- // Register a visitor guaranteed to be called on each chunk of
- // memory returned to the underlying device.
- virtual void AddFreeVisitor(Visitor visitor) = 0;
-};
-
-// Needed for cases when a VisitableAllocator gets wrapped for tracking.
-// Multiple-inheritance is considered acceptable in this case because
-// VisitableAllocator is a pure virtual interface and only TrackingAllocator
-// has default implementation.
-class TrackingVisitableAllocator : public TrackingAllocator,
- public VisitableAllocator {
- public:
- TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids)
- : TrackingAllocator(allocator, track_ids), allocator_(allocator) {}
- ~TrackingVisitableAllocator() override {}
-
- string Name() override { return TrackingAllocator::Name(); }
-
- void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return TrackingAllocator::AllocateRaw(alignment, num_bytes);
- }
-
- void DeallocateRaw(void* ptr) override {
- TrackingAllocator::DeallocateRaw(ptr);
- }
-
- void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
- }
-
- protected:
- VisitableAllocator* allocator_;
-};
-} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 38863db1cc..6994dec3b5 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -693,6 +693,7 @@ uint64 DebugFileIO::diskBytesUsed = 0;
mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+ mutex_lock l(bytes_mu);
if (globalDiskBytesLimit == 0) {
const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
if (env_tfdbg_disk_bytes_limit == nullptr ||
@@ -707,7 +708,6 @@ bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
if (bytes == 0) {
return true;
}
- mutex_lock l(bytes_mu);
if (diskBytesUsed + bytes < globalDiskBytesLimit) {
diskBytesUsed += bytes;
return true;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 6c146036ae..f7a2967d00 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
params.function_library = lib;
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
delete kernel;
}
};
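
Both lambdas now defer to OpSegment::ShouldOwnKernel, which appears to centralize the "stateful and not a function" test that used to be inlined here. A reconstruction of what such a helper would look like, derived from the removed condition (the actual implementation may differ):

    // Sketch reconstructed from the removed inline logic, not the real code.
    /* static */ bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
                                                 const string& node_op) {
      // OpSegment owns the kernel only if the op is stateful and is not a
      // function; function kernels must not be shared across subgraphs.
      return lib->IsStateful(node_op) &&
             lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
    }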
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index abd07e37b7..8e9eec1ed9 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
- c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
+ c->req.set_collective_graph_key(client_graph()->collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -1111,10 +1111,6 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
- if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
- h = Hash64Combine(opts.collective_graph_key, h);
- }
-
return h;
}
@@ -1788,10 +1784,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
- if (rcg->build_graph_options().collective_graph_key !=
+ if (rcg->client_graph()->collective_graph_key !=
BuildGraphOptions::kNoCollectiveGraphKey) {
env_->collective_executor_mgr->RetireStepId(
- rcg->build_graph_options().collective_graph_key, step_id);
+ rcg->client_graph()->collective_graph_key, step_id);
}
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@@ -1850,7 +1846,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- uint64 step_id = NewStepId(bgopts.collective_graph_key);
+ uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1914,8 +1910,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- const uint64 step_id =
- NewStepId(rcg->build_graph_options().collective_graph_key);
+ const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
index e7142a4ef9..e36e51d8d5 100644
--- a/tensorflow/core/example/example.proto
+++ b/tensorflow/core/example/example.proto
@@ -199,7 +199,13 @@ message Example {
// to determine if all features within the FeatureList must
// have the same size. The same holds for this FeatureList across multiple
// examples.
-//
+// - For sequence modeling, e.g.:
+// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
+// https://github.com/tensorflow/nmt
+// the feature lists represent a sequence of frames.
+// In this scenario, all FeatureLists in a SequenceExample have the same
+// number of Feature messages, so that the ith element in each FeatureList
+// is part of the ith frame (or time step).
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 888ed0c57b..84cee5569c 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
+void Allocator::RunVariantCtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+}
+
+void Allocator::RunVariantDtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+}
+
// If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.
@@ -187,7 +196,7 @@ class CPUAllocatorFactory : public AllocatorFactory {
class CPUSubAllocator : public SubAllocator {
public:
explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
- : cpu_allocator_(cpu_allocator) {}
+ : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {}
void* Alloc(size_t alignment, size_t num_bytes) override {
return cpu_allocator_->AllocateRaw(alignment, num_bytes);
@@ -213,4 +222,22 @@ Allocator* cpu_allocator() {
}
return cpu_alloc;
}
+
+SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {}
+
+void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) {
+ for (const auto& v : alloc_visitors_) {
+ v(ptr, index, num_bytes);
+ }
+}
+
+void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) {
+ // Although we don't guarantee any order of visitor application, strive
+ // to apply free visitors in reverse order of alloc visitors.
+ for (int i = free_visitors_.size() - 1; i >= 0; --i) {
+ free_visitors_[i](ptr, index, num_bytes);
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 774b1fe137..8c23604625 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,12 +23,14 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+class Variant;
+
// Attributes for a single allocation call. Different calls to the same
// allocator could potentially have different allocation attributes.
struct AllocationAttributes {
@@ -228,13 +230,9 @@ class Allocator {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
- virtual void RunVariantCtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
- }
+ virtual void RunVariantCtor(Variant* p, size_t n);
- virtual void RunVariantDtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
- }
+ virtual void RunVariantDtor(Variant* p, size_t n);
// TODO(jeff): Maybe provide some interface to give info about
// current allocation state (total number of bytes available for
@@ -390,13 +388,36 @@ void EnableCPUAllocatorStats(bool enable);
// full statistics. By default, it's disabled.
void EnableCPUAllocatorFullStats(bool enable);
-// Abstract interface of an object that does the underlying suballoc/free of
-// memory for a higher-level allocator.
+// An object that does the underlying suballoc/free of memory for a higher-level
+// allocator. The expectation is that the higher-level allocator is doing some
+// kind of cache or pool management so that it will call SubAllocator::Alloc and
+// Free relatively infrequently, compared to the number of times its own
+// AllocateRaw and Free methods are called.
class SubAllocator {
public:
+  // A Visitor is called with a pointer to a memory area and its
+  // size in bytes. The index value will be the numa_node for a CPU
+  // allocator and the GPU id for a GPU allocator.
+ typedef std::function<void(void*, int index, size_t)> Visitor;
+
+ SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors);
+
virtual ~SubAllocator() {}
virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
virtual void Free(void* ptr, size_t num_bytes) = 0;
+
+ protected:
+  // The implementation of Alloc() must call this on each newly allocated
+  // value.
+ void VisitAlloc(void* ptr, int index, size_t num_bytes);
+
+  // The implementation of Free() must call this on the value to be freed,
+  // immediately before deallocation.
+ void VisitFree(void* ptr, int index, size_t num_bytes);
+
+ const std::vector<Visitor> alloc_visitors_;
+ const std::vector<Visitor> free_visitors_;
};
} // namespace tensorflow
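
Subclasses of SubAllocator are now expected to pass their visitor lists to the base constructor and to call VisitAlloc/VisitFree around the raw operations, as BasicCPUAllocator does above. A minimal sketch of a conforming suballocator (LoggingSubAllocator is hypothetical; it mirrors the BasicCPUAllocator pattern and assumes tensorflow/core/platform/mem.h):

    class LoggingSubAllocator : public SubAllocator {
     public:
      LoggingSubAllocator(const std::vector<Visitor>& alloc_visitors,
                          const std::vector<Visitor>& free_visitors)
          : SubAllocator(alloc_visitors, free_visitors) {}

      void* Alloc(size_t alignment, size_t num_bytes) override {
        void* ptr = nullptr;
        if (num_bytes > 0) {
          ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
          VisitAlloc(ptr, 0 /*index*/, num_bytes);  // Required on new memory.
        }
        return ptr;
      }

      void Free(void* ptr, size_t num_bytes) override {
        if (num_bytes > 0) {
          VisitFree(ptr, 0 /*index*/, num_bytes);  // Call before releasing.
          port::AlignedFree(ptr);
        }
      }
    };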
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index 24f282ce84..e907c52ba9 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 1a3994736c..4ffd732f8e 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index b0b27ce94f..284dafb886 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
-
+namespace data {
namespace {
// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
@@ -179,6 +179,13 @@ Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
return Status::OK();
}
+void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Placeholder",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
+}
+
void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
Node** output) {
*output = ops::SourceOp(
@@ -322,4 +329,5 @@ void BackgroundWorker::WorkerLoop() {
}
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index e06ca68bca..91b1e61d3c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -40,6 +41,15 @@ limitations under the License.
namespace tensorflow {
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class Node;
+
+namespace data {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
+
class DatasetBase;
class SerializationContext;
@@ -66,11 +76,6 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
-// Forward declarations to avoid introducing a dependency on headers in
-// "tensorflow/core/graph/...".
-class GraphDefBuilder;
-class Node;
-
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@@ -110,10 +115,11 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- // Adds a Const node with Tensor value to the Graph.
+ // Adds a `Const` node for the given tensor value to the graph.
+ //
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
- // non-null if the method returns with an OK status.
- // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
Status AddTensor(const Tensor& val, Node** output) {
AddTensorInternal(val, output);
if (*output == nullptr) {
@@ -122,6 +128,20 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
+ // Adds a `Placeholder` node for the given tensor value to the graph.
+ //
+ // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
+ Status AddPlaceholder(const Tensor& val, Node** output) {
+ AddPlaceholderInternal(val, output);
+ if (*output == nullptr) {
+ return errors::Internal(
+ "AddPlaceholder: Failed to build Placeholder op.");
+ }
+ return Status::OK();
+ }
+
Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
@@ -168,6 +188,7 @@ class GraphDefBuilderWrapper {
}
private:
+ void AddPlaceholderInternal(const Tensor& val, Node** output);
void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
@@ -206,8 +227,7 @@ class GraphDefBuilderWrapper {
return (str_util::EndsWith(op_def->name(), "Dataset") &&
op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT) ||
- dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
- op_def->name());
+ WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
bool HasAttr(const string& op_type_name, const string& attr_name) const;
@@ -274,6 +294,9 @@ class IteratorContext {
// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+ // If non-null, identifies the object used for performance modeling.
+ std::shared_ptr<model::Model> model = nullptr;
};
explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -325,6 +348,10 @@ class IteratorContext {
return params_.stats_aggregator_getter;
}
+ std::shared_ptr<model::Model> model() { return params_.model; }
+
+ Params params() { return params_; }
+
private:
Params params_;
};
@@ -334,7 +361,8 @@ class SerializationContext {
public:
struct Params {
bool allow_stateful_functions = false;
- const FunctionLibraryDefinition* flib_def; // Not owned.
+ const FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
+ std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
@@ -343,6 +371,10 @@ class SerializationContext {
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
+ std::vector<std::pair<string, Tensor>>* input_list() {
+ return params_.input_list;
+ }
+
private:
Params params_;
@@ -354,7 +386,11 @@ class SerializationContext {
// defined below.
class IteratorBase {
public:
- virtual ~IteratorBase() {}
+ virtual ~IteratorBase() {
+ for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+ (*rit)();
+ }
+ }
// Gets the next output from the range that this iterator is traversing.
//
@@ -388,6 +424,10 @@ class IteratorBase {
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Returns a string that identifies the sequence of iterators leading up to
+ // this iterator.
+ virtual const string& prefix() const = 0;
+
// Performs initialization that needs to happen outside of a constructor to
// properly propagate errors.
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -427,6 +467,18 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
+
+ private:
+ friend class DatasetBase; // for access to `AddCleanupFunction`
+
+ // Registers a cleanup function to be called upon object destruction.
+ //
+  // Registered functions are invoked in the reverse order of registration.
+ void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
+ cleanup_fns_.push_back(std::move(cleanup_fn));
+ }
+
+ std::vector<std::function<void()>> cleanup_fns_;
};
// Represents runtime information needed to construct a dataset.
@@ -476,6 +528,27 @@ class DatasetBase : public core::RefCounted {
Status MakeIterator(IteratorContext* ctx, const string& prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
+ if (ctx->model()) {
+ // The prefix might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_prefix = prefix;
+ if (str_util::EndsWith(prefix, "]")) {
+ sanitized_prefix = prefix.substr(0, prefix.rfind('['));
+ }
+ std::shared_ptr<model::Node> node =
+ ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
+ std::vector<string> tokens =
+ str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
+ node->set_name(tokens[tokens.size() - 1]);
+ std::shared_ptr<model::Model> model = ctx->model();
+ const string& prefix = (*iterator)->prefix();
+ (*iterator)->AddCleanupFunction([model, node, prefix]() {
+ if (node->output()) {
+ node->output()->remove_input(node);
+ }
+ model->RemoveNode(prefix);
+ });
+ }
return (*iterator)->Initialize(ctx);
}
@@ -502,6 +575,8 @@ class DatasetBase : public core::RefCounted {
IteratorStateWriter* writer) const;
protected:
+ friend class DatasetToGraphOp; // For access to graph related members.
+
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -519,8 +594,6 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
- friend class DatasetToGraphOp; // For access to graph related members.
-
private:
const string name_;
};
@@ -543,7 +616,7 @@ class DatasetBaseIterator : public IteratorBase {
~DatasetBaseIterator() override { params_.dataset->Unref(); }
// The sequence of iterators leading up to this iterator.
- const string& prefix() const { return params_.prefix; }
+ const string& prefix() const override { return params_.prefix; }
const DataTypeVector& output_dtypes() const override {
return params_.dataset->output_dtypes();
@@ -556,7 +629,23 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
- Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ Status s;
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node =
+ ctx->model()->LookupNode(params_.prefix);
+ if (node->output()) {
+ node->output()->stop_work();
+ }
+ node->start_work();
+ s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ node->stop_work();
+ node->add_element();
+ if (node->output()) {
+ node->output()->start_work();
+ }
+ } else {
+ s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ }
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -583,6 +672,60 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
+ // When performance modeling is enabled, this method adds a constant parameter
+ // to the model node corresponding to this iterator.
+ void AddConstantParameter(IteratorContext* ctx, const string& name,
+ int64 value) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->add_constant_param(name, value);
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method adds a tunable parameter
+ // to the model node corresponding to this iterator.
+ //
+ // The `set_fn` function should set the tunable parameter to the value of
+ // its input argument. The function should be thread-safe; in particular, the
+ // state it updates should be protected by a lock as the function can be
+ // invoked asynchronously. It is guaranteed that this function will not be
+ // invoked after the iterator is deleted because the model node that owns
+ // the function is deleted when the iterator is deleted.
+ void AddTunableParameter(IteratorContext* ctx, const string& name,
+ int64 value, int64 min, int64 max,
+ std::function<void(int64)>&& set_fn) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->add_tunable_param(name, value, min, max, std::move(set_fn));
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has started work.
+ void StartWork(IteratorContext* ctx) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->start_work();
+ }
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has stopped work.
+ void StopWork(IteratorContext* ctx) {
+ if (ctx->model()) {
+ std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+ if (node) {
+ node->stop_work();
+ }
+ }
+ }
+
private:
BaseParams params_;
};
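
A minimal sketch of how a transformation iterator might use these hooks, assuming invented member names (`mu_`, `num_parallel_calls_`) and an invented batch size; none of these names are prescribed by the change itself:

    // Hypothetical fragment of a DatasetBaseIterator subclass.
    Status Initialize(IteratorContext* ctx) override {
      // Constant parameters feed the processing/output time formulas.
      AddConstantParameter(ctx, "batch_size", /*value=*/32);
      // Tunable parameters are adjusted by Model::Optimize via `set_fn`.
      AddTunableParameter(
          ctx, "parallelism", /*value=*/1, /*min=*/1,
          /*max=*/port::NumSchedulableCPUs(), [this](int64 parallelism) {
            mutex_lock l(mu_);  // `set_fn` must be thread-safe.
            num_parallel_calls_ = parallelism;
          });
      return Status::OK();
    }

    // A background thread brackets the time it spends doing useful work so the
    // model attributes processing time to this iterator.
    void RunnerThread(IteratorContext* ctx) {
      StartWork(ctx);
      while (ProduceOneElement()) {  // placeholder for the real worker loop
      }
      StopWork(ctx);
    }
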
@@ -730,6 +873,21 @@ class BackgroundWorker {
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
+} // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::DatasetBase;
+using data::DatasetContext;
+using data::DatasetIterator;
+using data::DatasetOpKernel;
+using data::IteratorBase;
+using data::IteratorContext;
+using data::IteratorStateReader;
+using data::IteratorStateWriter;
+using data::SerializationContext;
+using data::UnaryDatasetOpKernel;
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
index 3b48999edb..74bd39cb61 100644
--- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h
+++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
@@ -16,38 +16,38 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
+#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-namespace dataset {
+namespace data {
// Registry for stateful ops that need to be used in dataset functions.
// See below macro for usage details.
class WhitelistedStatefulOpRegistry {
public:
- Status Add(StringPiece op_name) {
- op_names_.insert(op_name);
+ Status Add(string op_name) {
+ op_names_.insert(std::move(op_name));
return Status::OK();
}
- bool Contains(StringPiece op_name) {
- return op_names_.find(op_name) != op_names_.end();
- }
+ bool Contains(const string& op_name) { return op_names_.count(op_name); }
static WhitelistedStatefulOpRegistry* Global() {
- static WhitelistedStatefulOpRegistry* reg =
- new WhitelistedStatefulOpRegistry;
+ static auto* reg = new WhitelistedStatefulOpRegistry;
return reg;
}
private:
- WhitelistedStatefulOpRegistry() {}
- WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
+ WhitelistedStatefulOpRegistry() = default;
+ WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
+ delete;
WhitelistedStatefulOpRegistry operator=(
- WhitelistedStatefulOpRegistry const& copy);
- std::set<StringPiece> op_names_;
+ WhitelistedStatefulOpRegistry const& copy) = delete;
+
+ std::unordered_set<string> op_names_;
};
-} // namespace dataset
+} // namespace data
// Use this macro to whitelist an op that is marked stateful but needs to be
// used inside a map_fn in an input pipeline. This is only needed if you wish
@@ -67,10 +67,9 @@ class WhitelistedStatefulOpRegistry {
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name)
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
-#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
- static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
- ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \
- name)
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
+ static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name)
} // namespace tensorflow
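
Typical use of the macro, with a placeholder op name:

    // Allow a stateful op to appear inside dataset functions (e.g. a map_fn).
    WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("MyStatefulOp");
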
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 794250a2c1..53ac639b4c 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -214,10 +214,12 @@ class DeviceBase {
// This is overridden by GPU devices to reinitialize the derived
// type returned by MakeGpuDevice.
- virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
- PerOpGpuDevice* /*device*/,
- DeviceContext* /*dc*/,
- Allocator* /*allocator*/) {}
+ virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/,
+ PerOpGpuDevice* /*device*/,
+ DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {
+ return Status::OK();
+ }
// Unimplemented by default
virtual const DeviceAttributes& attributes() const;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 26f32677af..a17959a448 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}
+string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
+ tf_shared_lock l(mu_);
+ int index = 0;
+ string name = strings::StrCat(prefix, index);
+ while (function_defs_.find(name) != function_defs_.end()) {
+ ++index;
+ name = strings::StrCat(prefix, index);
+ }
+ return name;
+}
+
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
@@ -1283,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create(
for (const auto& r : ret_def) {
fdef.mutable_ret()->insert({r.first, r.second});
}
+
+ auto* op_def_registry = OpRegistry::Global();
+ // Check if any op is stateful.
+ for (const auto& n : node_def) {
+ const OpDef* op_def = nullptr;
+ auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
+ // Lookup can fail if e.g. we are calling a function that was not yet
+ // defined. If it happens, conservatively assume the op is stateful.
+ if (!status.ok() || op_def->is_stateful()) {
+ fdef.mutable_signature()->set_is_stateful(true);
+ }
+ }
return fdef;
}
@@ -1344,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name,
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
+ if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
}
// Returns
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 03296a7761..e01eb7503d 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const OpRegistrationData** op_reg_data) const override
LOCKS_EXCLUDED(mu_);
+ // Generates new function name with the specified prefix that is unique
+ // across this library.
+ string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
+
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
static constexpr const char* const kArgOp = "_Arg";
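
A short sketch of the intended use of `UniqueFunctionName`; the prefix is illustrative:

    // Given a FunctionLibraryDefinition `flib_def`, pick a function name that
    // does not collide with anything already registered: "inlined_fn_0",
    // "inlined_fn_1", and so on.
    string new_name = flib_def.UniqueFunctionName("inlined_fn_");
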
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 46b169dddc..d5c203d276 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -91,6 +91,40 @@ FunctionDef IsZero() {
});
}
+FunctionDef RandomUniform() {
+ const Tensor kZero = test::AsScalar<int64>(0);
+ const Tensor kTen = test::AsScalar<int64>(10);
+
+ return FDH::Define(
+ // Name
+ "RandomUniform",
+ // Args
+ {"x: T"},
+ // Return values
+ {"random_uniform: int64"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ {{{"random_uniform/shape"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/min"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/max"},
+ "Const",
+ {},
+ {{"value", kTen}, {"dtype", DT_INT64}}},
+ {{"random_uniform"},
+ "RandomUniformInt",
+ {},
+ {{"T", DT_INT64},
+ {"Tout", DT_INT64},
+ {"seed", 87654321},
+ {"seed2", 42}}}});
+}
+
FunctionDef XTimesTwo() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
@@ -110,6 +144,22 @@ FunctionDef XTimesTwo() {
});
}
+FunctionDef XAddX() {
+ return FDH::Define(
+ // Name
+ "XAddX",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
+ });
+}
+
FunctionDef XTimesTwoInt32() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index 6d6476b936..a01743423b 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
// x:T -> x * 2.
FunctionDef XTimesTwo();
+// x:T -> x + x.
+FunctionDef XAddX();
+
// x:T -> x * 2, where x is int32.
FunctionDef XTimesTwoInt32();
@@ -81,6 +84,9 @@ FunctionDef NonZero();
// x: T -> bool.
FunctionDef IsZero();
+// x: T -> int64
+FunctionDef RandomUniform();
+
// x:T, y:T -> y:T, x:T
FunctionDef Swap();
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
new file mode 100644
index 0000000000..112298c344
--- /dev/null
+++ b/tensorflow/core/framework/model.cc
@@ -0,0 +1,365 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/model.h"
+
+#include <memory>
+
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// TODO(jsimsa): Use `Node` subclassing instead of types and switch statements.
+void Node::CollectTunables(
+ std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
+ mutex_lock l(mu_);
+ for (auto input : inputs_) {
+ input->CollectTunables(tunables);
+ }
+ switch (type_) {
+ case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_INTERLEAVE_V2:
+ case Type::PARALLEL_MAP: {
+ if (auto* tunable_param =
+ gtl::FindOrNull(tunable_params_, "parallelism")) {
+ tunables->push_back(*tunable_param);
+ }
+ return;
+ }
+ default:
+ return;
+ }
+}
+
+int64 Node::GetParameterValue(const string& name) {
+ if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
+ return (*tunable_param)->value;
+ }
+ return constant_params_[name];
+}
+
+int64 Node::ProcessingTimeLocked() {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::MAP_AND_BATCH:
+ case Type::PADDED_BATCH: {
+ int64 batch_size = GetParameterValue("batch_size");
+ return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ return NanosPerElementLocked() +
+ static_cast<int64>(ratio *
+ static_cast<double>(ProcessingTimeForInputs()));
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 processing_time =
+ ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
+ return NanosPerElementLocked() +
+ static_cast<double>(processing_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::PARALLEL_MAP:
+ case Type::PREFETCH:
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ return NanosPerElementLocked() + ProcessingTimeForInputs();
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::PADDED_BATCH: {
+ double batch_size = GetParameterValue("batch_size");
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) /
+ batch_size);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ batch_size * OutputTimeForInputs(input_times);
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) / ratio);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ static_cast<int64>(
+ static_cast<double>(OutputTimeForInputs(input_times)) * ratio);
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1`
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ int64 output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ return NanosPerElementLocked() +
+ static_cast<double>(output_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::MAP_AND_BATCH: {
+ double batch_size = GetParameterValue("batch_size");
+ double parallelism = GetParameterValue("parallelism");
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
+ (batch_size * parallelism));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ batch_size * OutputTimeForInputs(input_times));
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = GetParameterValue("parallelism");
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism =
+ std::min(static_cast<int>(GetParameterValue("cycle_length")),
+ static_cast<int>(GetParameterValue("parallelism")));
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_MAP: {
+ double parallelism =
+ std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(GetParameterValue("parallelism")));
+ int64 delta = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time =
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ OutputTimeForInputs(input_times);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PREFETCH: {
+ int64 delta = NanosPerElementLocked();
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ return std::max(0LL, NanosPerElementLocked() +
+ OutputTimeForInputs(input_times) -
+ input_times->at(input_times->size() - 2));
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ int64 delta = NanosPerElementLocked();
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+std::shared_ptr<Node> Model::AddNode(const string& name,
+ const string& output_name) {
+ mutex_lock l(mu_);
+ std::shared_ptr<Node> output;
+ auto it = lookup_table_.find(output_name);
+ if (it != lookup_table_.end()) {
+ output = it->second;
+ }
+ std::shared_ptr<Node> node(new Node(id_counter_++, output));
+ if (!output_) {
+ output_ = node;
+ }
+ if (output) {
+ output->add_input(node);
+ }
+ lookup_table_.insert(std::make_pair(name, node));
+ return node;
+}
+
+std::shared_ptr<Node> Model::LookupNode(const string& name) {
+ tf_shared_lock l(mu_);
+ std::shared_ptr<Node> result;
+ auto it = lookup_table_.find(name);
+ if (it != lookup_table_.end()) {
+ result = it->second;
+ }
+ return result;
+}
+
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by the CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+ mutex_lock l(optimization_mu_);
+ std::vector<std::shared_ptr<Node::Tunable>> tunables;
+ {
+ mutex_lock l2(mu_);
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
+ break;
+ }
+ int64 best_delta = -1;
+ Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
+ }
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
+ }
+ best_tunable->value++;
+ }
+ }
+ // The `set_fn` functions should be invoked without holding a lock to avoid a
+ // potential deadlock.
+ for (auto& tunable : tunables) {
+ tunable->set_fn(tunable->value);
+ }
+}
+
+void Model::RemoveNode(const string& prefix) {
+ // Nodes are not allowed to be removed when optimization is in progress to
+ // prevent the optimization from trying to access an iterator that was
+ // concurrently deleted.
+ mutex_lock l(optimization_mu_);
+ mutex_lock l2(mu_);
+ lookup_table_.erase(prefix);
+}
+
+std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Node::Tunable>> tunables;
+ output_->CollectTunables(&tunables);
+ return tunables;
+}
+
+int64 Model::OutputTime() {
+ std::vector<int64> input_times(1, 0);
+ return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
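
To make the batch arithmetic in `ProcessingTimeLocked` concrete, a standalone worked example with assumed measurements (1,000 ns per element in the Batch node, a single Map input at 2,000 ns per element, batch size 32):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Assumed measurements, chosen only for illustration.
      const int64_t nanos_per_element_batch = 1000;  // the Batch node itself
      const int64_t nanos_per_element_input = 2000;  // its single Map input
      const int64_t batch_size = 32;

      // Mirrors the Type::BATCH case:
      //   NanosPerElementLocked() + batch_size * ProcessingTimeForInputs()
      const int64_t processing_time =
          nanos_per_element_batch + batch_size * nanos_per_element_input;
      std::cout << processing_time << " ns per output element\n";  // 65000
      return 0;
    }
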
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000000..f88ec06ef3
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,379 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread> // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+class Model;
+class Node;
+
+// Abstract representation of a TensorFlow input pipeline node. It collects
+// information about inputs to this node, processing time spent executing the
+// node logic, the number of elements produced by the node, and various other
+// information (e.g. batch size or execution parallelism).
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting common information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+//
+// In addition, `DatasetBaseIterator` provides wrappers that can be used for
+// transformation-specific information collection. The `AddConstantParameter`
+// and `AddTunableParameter` wrappers can be used to pass parameters such as
+// batch size or parallelism to the modeling framework, while the `StartWork`
+// and `StopWork` wrappers should be used to correctly account for the
+// processing time of multi-threaded transformations that yield the CPU; such
+// transformations should invoke `StartWork()` when a transformation thread
+// starts executing (e.g. when created or woken up) and `StopWork()` when a
+// transformation thread stops executing (e.g. when returning or waiting).
+//
+// TODO(jsimsa): Create an API to capture the abstract semantics of each
+// tf.data transformation and replace switch-case blocks with inheritance.
+class Node {
+ public:
+ Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
+
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
+ }
+
+ // Records that the node produced an element.
+ void add_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
+ }
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
+ }
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name, int64 value, int64 min, int64 max,
+ std::function<void(int64)>&& set_fn)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(value, min, max, std::move(set_fn));
+ }
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
+ }
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
+ }
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
+ }
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
+ }
+
+ // Sets the node name.
+ void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ name_ = name;
+ type_ = TypeFromName(name);
+ }
+
+ // Set the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
+
+ // Records that a node thread has started work.
+ void start_work() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+ }
+
+ // Records that a node thread has stopped work.
+ void stop_work() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ auto iter = work_start_.find(std::this_thread::get_id());
+ CHECK(work_start_.end() != iter)
+ << "Encountered a stop event that was not preceded by a start event.";
+ processing_time_ += Env::Default()->NowNanos() - iter->second;
+ work_start_.erase(iter);
+ }
+
+ private:
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(int64 value, int64 min, int64 max,
+ std::function<void(int64)> set_fn)
+ : value(value), min(min), max(max), set_fn(std::move(set_fn)) {}
+
+ int64 value;
+ int64 min;
+ int64 max;
+ std::function<void(int64)> set_fn;
+ };
+
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
+
+  // Gets the value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return NanosPerElementLocked();
+ }
+
+ int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return (int64)((double)processing_time_ / (double)num_elements_);
+ }
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
+
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
+
+ // Returns the per-element processing time spent in the subtree rooted in this
+ // node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
+
+ int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTimeLocked();
+ }
+ return sum;
+ }
+
+ Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
+
+ mutex mu_;
+ const int64 id_;
+ Type type_ GUARDED_BY(mu_);
+ string name_ GUARDED_BY(mu_);
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+
+ friend class Model;
+};
+
+// Abstract representation of a TensorFlow input pipeline that can be used
+// for collecting runtime information and optimizing performance. It collects
+// runtime information about the execution of the input pipeline that is used to
+// create a performance model, which is in turn used to identify optimal values
+// of tunable parameters.
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boiler plate code for creating the abstract representation of
+// the input pipeline and collecting runtime information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+class Model {
+ public:
+ Model() = default;
+
+ // Returns the model output node.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+  // Adds a node with the given name and the given output (identified by `output_name`).
+ std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
+
+ // Looks up the node using the given name.
+ std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Runs optimization.
+ void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
+
+ // Removes the node identified by the given name.
+ void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+
+ private:
+ std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Used for coordination between different input pipeline threads.
+ mutex mu_;
+ // Used for preventing iterator deletion when optimization is in progress
+ // because the optimization may try to update the values of tunable
+ // parameters.
+ mutex optimization_mu_ ACQUIRED_BEFORE(mu_);
+ int64 id_counter_ GUARDED_BY(mu_) = 1;
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+};
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
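
A hedged sketch of driving the model by hand; in this change the nodes are registered by `DatasetBase::MakeIterator` and the trigger for `Optimize` lives elsewhere, so the explicit wiring, prefix strings, and 10 ms period below are assumptions for illustration only:

    auto model = std::make_shared<model::Model>();

    // The pipeline root has no output node; downstream iterators name their
    // consumer via `output_name`.
    auto prefetch = model->AddNode("Iterator::Prefetch", "");
    prefetch->set_name("Prefetch");
    auto parallel_map =
        model->AddNode("Iterator::Prefetch::ParallelMap", "Iterator::Prefetch");
    parallel_map->set_name("ParallelMap");

    // Periodically re-optimize tunable parameters against the CPU budget.
    std::thread optimizer([model] {
      while (true) {
        model->Optimize(/*cpu_budget=*/port::NumSchedulableCPUs());
        Env::Default()->SleepForMicroseconds(10 * 1000);
      }
    });
    optimizer.detach();  // a real implementation would manage this thread's lifetime
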
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index bacc1d72c4..42ec315a32 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -403,6 +403,14 @@ Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
return OutputTypesForNode(node_def, op_def, outputs);
}
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs) {
+ DataTypeVector outputs;
+ TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
+ *num_outputs = outputs.size();
+ return Status::OK();
+}
+
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
if (node_def.op() != op_def.name()) {
return errors::InvalidArgument("NodeDef op '", node_def.op(),
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 499034cab2..7528d3d306 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -261,6 +261,10 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs);
+// Computes the number of outputs for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs);
// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c694e10193..3e34bf0418 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
// OpKernel ------------------------------------------------------------------
-// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
- : OpKernel(context,
- std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+ : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
OpKernel::OpKernel(OpKernelConstruction* context,
std::unique_ptr<const NodeDef> node_def)
@@ -266,9 +265,12 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
params_->ensure_eigen_gpu_device();
if (params_->eigen_gpu_device != nullptr) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ Status s = params_->device->ReinitializeGpuDevice(
+ this, params_->eigen_gpu_device, params_->op_device_context,
+ eigen_gpu_allocator);
+ if (!s.ok()) {
+ SetStatus(s);
+ }
}
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
@@ -525,10 +527,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input(
return nullptr;
}
}
- // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
- // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
- // general cleanup of ownership in this code.
- std::unique_ptr<Tensor> output_tensor(new Tensor());
+
+ auto output_tensor = MakeUnique<Tensor>();
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
return output_tensor;
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e752599de1..4bbd6c3d7d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -372,18 +372,37 @@ class OpKernelConstruction {
template <typename ListType, typename ElementType>
class OpArgIterator {
public:
- typedef OpArgIterator<ListType, ElementType> ME;
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = ElementType;
+ using pointer = ElementType*;
+ using reference = ElementType&;
+ using difference_type = ptrdiff_t;
+
OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
- bool operator==(const ME& rhs) {
+
+ bool operator==(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ == rhs.i_;
}
- bool operator!=(const ME& rhs) {
+
+ bool operator!=(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ != rhs.i_;
}
- void operator++() { ++i_; }
- ElementType& operator*() { return (*list_)[i_]; }
+
+ OpArgIterator operator++() { // prefix ++it
+ ++i_;
+ return *this;
+ }
+
+ OpArgIterator operator++(int) { // postfix it++
+ OpArgIterator old_value = *this;
+ ++i_;
+ return old_value;
+ }
+
+ reference operator*() { return (*list_)[i_]; }
+ pointer operator->() { return &(*list_)[i_]; }
private:
const ListType* const list_;
@@ -394,7 +413,7 @@ class OpArgIterator {
// that are passed to the op as a single named argument.
class OpInputList {
public:
- typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext* ctx, int start, int stop)
: ctx_(ctx), start_(start), stop_(stop) {}
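
With standard iterator typedefs in place, `OpInputList` composes with range-based for loops and STL algorithms; a hypothetical kernel fragment, where the `values` input name is an assumption:

    // Hypothetical fragment of an OpKernel::Compute.
    void Compute(OpKernelContext* ctx) override {
      OpInputList values;
      OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
      int64 total_elements = 0;
      for (const Tensor& t : values) {  // iterates via OpArgIterator
        total_elements += t.NumElements();
      }
      // ... use total_elements ...
    }
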
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
index dfc5aa7747..75ed4a4eaf 100644
--- a/tensorflow/core/framework/op_segment.cc
+++ b/tensorflow/core/framework/op_segment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) {
delete item;
}
+bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op) {
+  // OpSegment should not own the kernel if the node is stateless or a function.
+ return lib->IsStateful(node_op) &&
+ lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
+}
+
} // end namespace tensorflow
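
A sketch of the intended call site, under the assumption that an `OpSegment* opseg`, a `FunctionLibraryRuntime* lib`, a `NodeDef ndef`, and a `session_handle` string are available where kernels are created:

    OpKernel* kernel = nullptr;
    if (OpSegment::ShouldOwnKernel(lib, ndef.op())) {
      // Stateful, non-function op: cache the kernel in the session's OpSegment.
      auto create_fn = [lib, &ndef](OpKernel** k) {
        return lib->CreateKernel(ndef, k);
      };
      TF_RETURN_IF_ERROR(
          opseg->FindOrCreate(session_handle, ndef.name(), &kernel, create_fn));
    } else {
      // Stateless op or function: the caller owns (and later deletes) the kernel.
      TF_RETURN_IF_ERROR(lib->CreateKernel(ndef, &kernel));
    }
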
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
index 4433a2554f..37d939ea2b 100644
--- a/tensorflow/core/framework/op_segment.h
+++ b/tensorflow/core/framework/op_segment.h
@@ -60,6 +60,10 @@ class OpSegment {
Status FindOrCreate(const string& session_handle, const string& node_name,
OpKernel** kernel, CreateKernelFn create_fn);
+ // Returns true if OpSegment should own the kernel.
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op);
+
private:
// op name -> OpKernel
typedef std::unordered_map<string, OpKernel*> KernelMap;
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 0a19861efd..ebdaaec153 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -271,7 +271,7 @@ string ContainerInfo::DebugString() const {
"]");
}
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
return ctx->input(input).flat<ResourceHandle>()(0);
}
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index f8a587c9b5..d58deaa3fc 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -79,7 +79,7 @@ class ResourceBase : public core::RefCounted {
virtual string DebugString() = 0;
// Returns memory used by this resource.
- virtual int64 MemoryUsed() const { return 0; };
+ virtual int64 MemoryUsed() const { return 0; }
};
// Container used for per-step resources.
@@ -234,7 +234,7 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
const string& name);
// Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
@@ -348,6 +348,8 @@ class ResourceHandleOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
private:
string container_;
string name_;
diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h
index 4a18efc940..af53ed0a3c 100644
--- a/tensorflow/core/framework/stats_aggregator.h
+++ b/tensorflow/core/framework/stats_aggregator.h
@@ -25,6 +25,8 @@ namespace tensorflow {
class Summary;
+namespace data {
+
// A `StatsAggregator` accumulates statistics incrementally. A
// `StatsAggregator` can accumulate multiple different statistics, distinguished
// by a string name.
@@ -87,6 +89,7 @@ class StatsAggregatorResource : public ResourceBase {
const std::shared_ptr<StatsAggregator> stats_aggregator_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 1b19ab5da3..696fd277cd 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -37,11 +37,12 @@ namespace tensorflow {
class AllocationDescription;
class Allocator;
class OpKernelContext;
+class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
-class VariantTensorData;
+
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 84a373c196..9a78cdc91e 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/math/math_util.h"
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 4bda8f9eb8..a7cf600bab 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include <vector>
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 15b1add2c1..2e96b05787 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -39,6 +38,8 @@ limitations under the License.
namespace tensorflow {
+class Variant;
+
// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 5a507804b0..d43e3c72ec 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -23,11 +23,11 @@ limitations under the License.
namespace tensorflow {
-bool Variant::TryDecode(Variant* out) const {
- const VariantTensorDataProto* p = get<VariantTensorDataProto>();
- if (p == nullptr) return false;
- VariantTensorData data(*p);
- return out->Decode(data);
+bool Variant::Decode(VariantTensorData data) {
+ if (!is_empty()) {
+ return value_->Decode(std::move(data));
+ }
+ return true;
}
template <>
@@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) {
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data) {
- data->FromProto(value);
+ data->FromConstProto(value);
}
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value) {
- data.ToProto(value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
+ data->ToProto(value);
return true;
}
@@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
}
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
- return value->ParseFromString(buf);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(*buf);
}
void EncodeVariantList(const Variant* variant_array, int64 n,
@@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
if (variant_array[i].is_empty()) {
variant_array[i] = VariantTensorDataProto();
}
+ // TODO(ebrevdo): Replace with StringPiece? Any way to make this a
+ // zero-copy operation that keeps a reference to the data in d?
string str(d->Data(sizes[i]), sizes[i]);
- if (!variant_array[i].Decode(str)) return false;
+ if (!variant_array[i].Decode(std::move(str))) return false;
if (!DecodeUnaryVariant(&variant_array[i])) {
LOG(ERROR) << "Could not decode variant with type_name: \""
<< variant_array[i].TypeName()
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index 52732801a0..10eabbc85f 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -23,7 +23,6 @@ limitations under the License.
#include <unordered_map>
#include <utility>
-#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h"
@@ -38,17 +37,19 @@ string TypeNameVariant(const T& value);
template <typename T>
string DebugStringVariant(const T& value);
+// Allows for specializations of Variant Decoding. `data` may be modified in
+// the process of decoding to `value`.
template <typename T>
-void EncodeVariant(const T& value, VariantTensorData* data);
+bool DecodeVariant(VariantTensorData* data, T* value);
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value);
+bool DecodeVariant(string* buf, T* value);
template <typename T>
-void EncodeVariant(const T& value, string* buf);
+void EncodeVariant(const T& value, VariantTensorData* data);
template <typename T>
-bool DecodeVariant(const string& buf, T* value);
+void EncodeVariant(const T& value, string* buf);
// This is an implementation of a type-erased container that can store an
// object of any type. The implementation is very similar to std::any, but has
@@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value);
//
// string TypeName() const;
// void Encode(VariantTensorData* data) const;
-// void Decode(const VariantTensorData& data);
+// void Decode(VariantTensorData data);
//
// Simple POD types can elide the Encode/Decode functions, they are provided by
// helper methods.
@@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value);
// x.Encode(&serialized_f);
//
// Variant y = Foo(); // default constructed Foo.
-// y.Decode(&serialized_f);
+// y.Decode(std::move(serialized_f));
// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
//
//
@@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value);
// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
// EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
// y_type_unknown.TypeId());
-// // Decode and get y_type_unknown; compare to value in x.
-// Foo f_decoded;
-// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded));
-// EXPECT_EQ(f_decoded, f);
//
class Variant {
public:
@@ -241,12 +238,7 @@ class Variant {
}
// Deserialize `data` and update the stored object.
- bool Decode(const VariantTensorData& data) {
- if (!is_empty()) {
- return value_->Decode(data);
- }
- return true;
- }
+ bool Decode(VariantTensorData data);
// Helper methods to directly serialize/deserialize from strings.
void Encode(string* buf) const {
@@ -254,31 +246,13 @@ class Variant {
value_->Encode(buf);
}
}
- bool Decode(const string& buf) {
+ bool Decode(string buf) {
if (!is_empty()) {
- return value_->Decode(buf);
+ return value_->Decode(std::move(buf));
}
return true;
}
- template <typename T>
- bool MaybeDecodeAndCopy(T* out) const {
- const T* ret = get<T>();
- if (ret != nullptr) {
- *out = std::move(*ret);
- return true;
- };
- Variant decoded = T();
- if (!TryDecode(&decoded)) return false;
- T* decoded_ret = decoded.get<T>();
- CHECK_NOTNULL(decoded_ret);
- *out = std::move(*decoded_ret);
- return true;
- }
-
- private:
- bool TryDecode(Variant* out) const;
-
private:
struct in_place_t {};
static constexpr in_place_t in_place{};
@@ -292,9 +266,9 @@ class Variant {
virtual string TypeName() const = 0;
virtual string DebugString() const = 0;
virtual void Encode(VariantTensorData* data) const = 0;
- virtual bool Decode(const VariantTensorData& data) = 0;
+ virtual bool Decode(VariantTensorData data) = 0;
virtual void Encode(string* buf) const = 0;
- virtual bool Decode(const string& data) = 0;
+ virtual bool Decode(string data) = 0;
};
template <typename T>
@@ -325,15 +299,13 @@ class Variant {
EncodeVariant(value, data);
}
- bool Decode(const VariantTensorData& data) override {
- return DecodeVariant(data, &value);
+ bool Decode(VariantTensorData data) override {
+ return DecodeVariant(&data, &value);
}
void Encode(string* buf) const override { EncodeVariant(value, buf); }
- bool Decode(const string& buf) override {
- return DecodeVariant(buf, &value);
- }
+ bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
T value;
};
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index f155aa4892..5e08e5a7a6 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/abi.h"
@@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value,
// Specialization for POD type
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, true /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for tensorflow::Tensor
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, true /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for protobuf
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
T* value) {
@@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for other types
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
- return value->Decode(data);
+ return value->Decode(std::move(data));
}
template <typename C, typename = void>
@@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) {
}
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value) {
- return DecodeVariantImpl(data, TypeResolver<T>(), value);
+bool DecodeVariant(VariantTensorData* data, T* value) {
+ return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
}
template <typename T>
@@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) {
}
template <typename T>
-bool DecodeVariant(const string& buf, T* value) {
+bool DecodeVariant(string* buf, T* value) {
VariantTensorData data;
- if (!data.ParseFromString(buf)) return false;
- if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+ if (!data.ParseFromString(*buf)) return false;
+ if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
+ return false;
+ }
return true;
}
// Specializations for VariantTensorDataProto
template <>
string TypeNameVariant(const VariantTensorDataProto& value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data);
+
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
+
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value);
// Encodes an array of Variant objects into the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
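Call sites follow the same pattern: the decode helpers now take the buffer by pointer and may consume it. A hedged call-site sketch reusing the hypothetical MyCounter type from the sketch above; the producers of buf and data below are also hypothetical:

    // Decode from a serialized string; buf may be left moved-from.
    string buf = GetSerializedVariantSomehow();
    MyCounter value;
    if (!DecodeVariant(&buf, &value)) {
      // Parsing failed or the payload did not describe a MyCounter.
    }

    // Decode from VariantTensorData; data is consumed by the call.
    VariantTensorData data = GetVariantTensorDataSomehow();
    MyCounter other;
    bool ok = DecodeVariant(&data, &other);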
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 60fa7bd559..daa744e877 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyCPUToGPU);
+ StoredTensorValue::CopyCPUToGPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST,
- "StoredTensorValue", StoredTensorValue::CopyGPUToCPU);
+ StoredTensorValue::CopyGPUToCPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyGPUToGPU);
+ StoredTensorValue::CopyGPUToGPU);
REGISTER_OP("CreateTestVariant")
.Input("input: T")
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index ee07db1aee..ef5b240aea 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
}
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
- StringPiece type_name) {
- auto found = shape_fns.find(type_name);
+ const TypeIndex& type_index) {
+ auto found = shape_fns.find(type_index);
if (found == shape_fns.end()) return nullptr;
return &found->second;
}
-void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
+void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
- VariantShapeFn* existing = GetShapeFn(type_name);
+ VariantShapeFn* existing = GetShapeFn(type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantShapeFn for type_name: " << type_name
- << " already registered";
- shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
- GetPersistentStringPiece(type_name), shape_fn));
+ << "Unary VariantShapeFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name()) << " already registered";
+ shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
}
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
CHECK_EQ(variant_tensor.dims(), 0);
const Variant& v = variant_tensor.scalar<Variant>()();
UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
if (shape_fn == nullptr) {
return errors::Internal(
- "No unary variant shape function found for Variant type_name: ",
- v.TypeName());
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(v.TypeId().name()));
}
return (*shape_fn)(v, shape);
}
@@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) {
} // namespace
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);
// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double);
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
UnaryVariantOpRegistry::GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name) {
- auto found = device_copy_fns.find(std::make_pair(direction, type_name));
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
+ auto found = device_copy_fns.find(std::make_pair(direction, type_index));
if (found == device_copy_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
- AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
+ AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
CHECK_EQ(existing, nullptr)
<< "UnaryVariantDeviceCopy for direction: " << direction
- << " and type_name: " << type_name << " already registered";
+ << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
+ << " already registered";
device_copy_fns.insert(
- std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn>(
- std::make_pair(direction, GetPersistentStringPiece(type_name)),
- device_copy_fn));
+ std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
+ device_copy_fn));
}
Status VariantDeviceCopy(
@@ -170,35 +167,34 @@ Status VariantDeviceCopy(
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
- from.TypeName());
+ from.TypeId());
if (device_copy_fn == nullptr) {
return errors::Internal(
"No unary variant device copy function found for direction: ",
- direction, " and Variant type_name: ", from.TypeName());
+ direction, " and Variant type_index: ",
+ port::MaybeAbiDemangle(from.TypeId().name()));
}
return (*device_copy_fn)(from, to, copy_fn);
}
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
- VariantUnaryOp op, StringPiece device, StringPiece type_name) {
- auto found = unary_op_fns.find({op, device, type_name});
+ VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
+ auto found = unary_op_fns.find({op, device, type_index});
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
- VariantUnaryOp op, const string& device, const string& type_name,
+ VariantUnaryOp op, const string& device, const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
- VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
+ VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantUnaryOpFn for type_name: " << type_name
+ << "Unary VariantUnaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- unary_op_fn));
+ {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
}
namespace {
@@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
- DEVICE_CPU, T, TF_STR(T), \
+ DEVICE_CPU, T, \
ZerosLikeVariantPrimitiveType<T>);
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
// Special casing BinaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name) {
- auto found = binary_op_fns.find({op, device, type_name});
+ const TypeIndex& type_index) {
+ auto found = binary_op_fns.find({op, device, type_index});
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
- VariantBinaryOp op, const string& device, const string& type_name,
+ VariantBinaryOp op, const string& device, const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
- VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
+ VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantBinaryOpFn for type_name: " << type_name
+ << "Unary VariantBinaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- add_fn));
+ {op, GetPersistentStringPiece(device), type_index}, add_fn));
}
namespace {
@@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
#define REGISTER_VARIANT_ADD_TYPE(T) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
- T, TF_STR(T), \
- AddVariantPrimitiveType<T>);
+ T, AddVariantPrimitiveType<T>);
// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
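The registry is now keyed on TypeIndex rather than on a type-name string. A hedged sketch of the pieces involved (MyCounter is hypothetical; the demangled output is compiler-dependent):

    // TypeIndex identifies the C++ type; its hash drives the registry maps,
    // and MaybeAbiDemangle produces a readable name for error messages.
    const TypeIndex idx = MakeTypeIndex<MyCounter>();
    const std::size_t key_hash = idx.hash_code();
    const string readable = port::MaybeAbiDemangle(idx.name());  // e.g. "MyCounter"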
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index e6a2665a56..7eb37e859f 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -22,10 +22,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/abi.h"
namespace tensorflow {
@@ -90,10 +94,11 @@ class UnaryVariantOpRegistry {
AsyncVariantDeviceCopyFn;
// Add a shape lookup function to the registry.
- void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
+ void RegisterShapeFn(const TypeIndex& type_index,
+ const VariantShapeFn& shape_fn);
- // Returns nullptr if no shape function was found for the given TypeName.
- VariantShapeFn* GetShapeFn(StringPiece type_name);
+ // Returns nullptr if no shape function was found for the given TypeIndex.
+ VariantShapeFn* GetShapeFn(const TypeIndex& type_index);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
@@ -104,33 +109,33 @@ class UnaryVariantOpRegistry {
// Add a copy-to-GPU function to the registry.
void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
- const string& type_name,
+ const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn);
// Returns nullptr if no copy function was found for the given
// TypeName and direction.
AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name);
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
// Add a unary op function to the registry.
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn);
// Returns nullptr if no unary op function was found for the given
// op, device, and TypeName.
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Add a binary op function to the registry.
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn);
// Returns nullptr if no binary op function was found for the given
// op, device and TypeName.
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
@@ -145,24 +150,26 @@ class UnaryVariantOpRegistry {
static std::unordered_set<string>* PersistentStringStorage();
private:
- std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
- std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
- decode_fns;
+ struct TypeIndexHash {
+ std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
+ };
+
+ gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
+ gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
// Map std::pair<Direction, type_name> to function.
struct PairHash {
template <typename Direction>
- std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
+ std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
- ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
+ ret = Hash64Combine(ret, std::get<1>(x).hash_code());
return ret;
}
- StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn, PairHash>
+ gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn, PairHash>
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
@@ -172,10 +179,11 @@ class UnaryVariantOpRegistry {
// and references therein
template <typename Op>
struct FuncTuple {
- FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
- : op_type_(op), device_(dev), typename_(tname){};
+ FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
+ : op_type_(op), device_(dev), type_index_(type_index) {}
Op op_type_;
- StringPiece device_, typename_;
+ StringPiece device_;
+ TypeIndex type_index_;
};
// friend declaration for operator==
// needed for clang
@@ -184,11 +192,11 @@ class UnaryVariantOpRegistry {
struct TupleHash {
template <typename Op>
std::size_t operator()(
- const std::tuple<Op, StringPiece, StringPiece>& x) const {
+ const std::tuple<Op, StringPiece, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
- ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
+ ret = Hash64Combine(ret, std::get<2>(x).hash_code());
return ret;
}
@@ -197,14 +205,14 @@ class UnaryVariantOpRegistry {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(x.op_type_);
ret = Hash64Combine(ret, sp_hasher_(x.device_));
- ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ ret = Hash64Combine(ret, x.type_index_.hash_code());
return ret;
}
StringPieceHasher sp_hasher_;
};
- std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
@@ -225,7 +233,7 @@ template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
- (lhs.typename_ == rhs.typename_);
+ (lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
@@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
Variant* v_out) {
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
- UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant unary_op function found for unary variant op enum: ",
@@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
const Variant& a, const Variant& b, Variant* out) {
- if (a.TypeName() != b.TypeName()) {
+ if (a.TypeId() != b.TypeId()) {
return errors::Internal(
"BianryOpVariants: Variants a and b have different "
- "type names: '",
+ "type ids. Type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
- UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+ UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
@@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration {
public:
typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
- UnaryVariantShapeRegistration(const string& type_name,
+ UnaryVariantShapeRegistration(const TypeIndex& type_index,
const LocalVariantShapeFn& shape_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterShapeFn(
- type_name,
- [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
+ type_index,
+ [type_index_name, shape_fn](const Variant& v,
+ TensorShape* s) -> Status {
const T* t = v.get<T>();
if (t == nullptr) {
return errors::Internal(
- "VariantShapeFn: Could not access object, type_name: ",
- type_name);
+ "VariantShapeFn: Could not access object, type_index: ",
+ type_index_name);
}
return shape_fn(*t, s);
});
@@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration {
return false;
}
Variant decoded = T();
- VariantTensorData data(*t);
- if (!decoded.Decode(data)) {
+ VariantTensorData data(std::move(*t));
+ if (!decoded.Decode(std::move(data))) {
return false;
}
- *v = std::move(decoded);
+ std::swap(decoded, *v);
return true;
});
}
@@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration {
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const LocalVariantDeviceCopyFn& device_copy_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
- direction, type_name,
- [type_name, device_copy_fn](
+ direction, type_index,
+ [type_index_name, device_copy_fn](
const Variant& from, Variant* to,
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
device_copy_tensor_fn) -> Status {
@@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration {
*to = T();
if (from.get<T>() == nullptr) {
return errors::Internal(
- "VariantCopyToGPUFn: Could not access object, type_name: ",
- type_name);
+ "VariantCopyToGPUFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *from.get<T>();
T* t_out = to->get<T>();
@@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration {
public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantUnaryOpFn& unary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
- op, device, type_name,
- [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
- Variant* v_out) -> Status {
+ op, device, type_index,
+ [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
+ Variant* v_out) -> Status {
DCHECK_NE(v_out, nullptr);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
- "VariantUnaryOpFn: Could not access object, type_name: ",
- type_name);
+ "VariantUnaryOpFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
@@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration {
public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantBinaryOpFn& binary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
- op, device, type_name,
- [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
- const Variant& b, Variant* out) -> Status {
+ op, device, type_index,
+ [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
+ const Variant& b,
+ Variant* out) -> Status {
DCHECK_NE(out, nullptr);
*out = T();
if (a.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'a', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'a', type_index: ",
+ type_index_name);
}
if (b.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'b', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'b', type_index: ",
+ type_index_name);
}
const T& t_a = *a.get<T>();
const T& t_b = *b.get<T>();
@@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration {
// Register a unary shape variant function with the signature:
// Status ShapeFn(const T& t, TensorShape* s);
-// to Variants having TypeName type_name.
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
- shape_function)
+// to Variants having TypeIndex type_index.
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, MakeTypeIndex<T>(), shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
- shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
+ shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \
shape_function) \
static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
- register_unary_variant_op_shape_registration_fn_##ctr(type_name, \
+ register_unary_variant_op_shape_registration_fn_##ctr(type_index, \
shape_function)
// Register a unary decode variant function for the given type.
@@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration {
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- T, direction, type_name, device_copy_fn) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, T, direction, type_name, device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \
+ device_copy_fn) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- ctr, T, direction, type_name, device_copy_fn) \
+ ctr, T, direction, type_index, device_copy_fn) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn)
+ ctr, T, direction, type_index, device_copy_fn)
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantDeviceCopyRegistration<T> \
- register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
- device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
+ ctr, T, direction, type_index, device_copy_fn) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantDeviceCopyRegistration<T> \
+ register_unary_variant_op_device_copy_fn_##ctr( \
+ direction, type_index, device_copy_fn)
// Register a unary unary_op variant function with the signature:
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for UnaryVariantOp enum op.
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
- unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \
+ unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
- unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_index, unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
+ type_index, unary_op_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, unary_op_function) \
+ ctr, op, device, T, type_index, unary_op_function) \
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
unary_op_function)
// Register a binary_op variant function with the signature:
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for BinaryVariantOp enum OP.
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
- binary_op_function) \
- REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, binary_op_function)
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \
+ binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, binary_op_function) \
+ ctr, op, device, T, type_index, binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function)
+ ctr, op, device, T, type_index, binary_op_function)
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantBinaryOpRegistration<T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_index, binary_op_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantBinaryOpRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
binary_op_function)
} // end namespace tensorflow
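Under the updated macros the type-name string argument is gone; MakeTypeIndex<T>() is supplied internally. A hedged usage sketch for the hypothetical MyCounter type from earlier; the functions shown are illustrative, not part of this patch:

    // Shape function: Status ShapeFn(const T& t, TensorShape* s).
    Status MyCounterShape(const MyCounter& c, TensorShape* s) {
      *s = TensorShape({});  // report a stored counter as a scalar
      return Status::OK();
    }

    // Binary op: Status BinaryOpFn(OpKernelContext*, const T&, const T&, T*).
    Status MyCounterAdd(OpKernelContext* ctx, const MyCounter& a,
                        const MyCounter& b, MyCounter* out) {
      out->count = a.count + b.count;
      return Status::OK();
    }

    REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyCounter, MyCounterShape);
    REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                              MyCounter, MyCounterAdd);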
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62c0e..b2443e8676 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@ struct VariantValue {
int value;
};
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
- VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+ VariantValue::CPUToGPUCopyFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
- "TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
- "TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- VariantValue, "TEST VariantValue",
- VariantValue::CPUAddFn);
+ VariantValue, VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- VariantValue, "TEST VariantValue",
- VariantValue::GPUAddFn);
+ VariantValue, VariantValue::GPUAddFn);
} // namespace
TEST(VariantOpShapeRegistryTest, TestBasic) {
- EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+ class Blah {};
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);
- auto* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+ auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;
@@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
- string kTypeName = "fjfjfj";
- registry.RegisterShapeFn(kTypeName, f);
- EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
- "fjfjfj already registered");
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+ registry.RegisterShapeFn(kTypeIndex, f);
+ EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
// No registered copy fn for GPU<->GPU.
- EXPECT_EQ(
- UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
- nullptr);
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+ MakeTypeIndex<VariantValue>()),
+ nullptr);
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE,
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(copy_to_gpu_fn, nullptr);
VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
- kTypeName, f);
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
- "fjfjfj already registered");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_CPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_CPU, kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_GPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_GPU, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
- return;
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index 99712dc114..3e67e4a864 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -22,8 +22,8 @@ namespace tensorflow {
VariantTensorData::VariantTensorData() {}
-VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) {
- FromProto(proto);
+VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
+ FromProto(std::move(proto));
}
VariantTensorData::~VariantTensorData() {}
@@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
}
}
-bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+bool VariantTensorData::FromProto(VariantTensorDataProto proto) {
+ // TODO(ebrevdo): Do this lazily.
+ set_type_name(proto.type_name());
+ set_metadata(proto.metadata());
+ for (const auto& tensor : proto.tensors()) {
+ Tensor tmp;
+ if (!tmp.FromProto(tensor)) return false;
+ tensors_.push_back(tmp);
+ }
+ return true;
+}
+
+bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) {
set_type_name(proto.type_name());
set_metadata(proto.metadata());
for (const auto& tensor : proto.tensors()) {
@@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) {
return proto.SerializeToString(buf);
}
-bool VariantTensorData::ParseFromString(const string& s) {
+bool VariantTensorData::ParseFromString(string s) {
VariantTensorDataProto proto;
const bool status = proto.ParseFromString(s);
- if (status) FromProto(proto);
+ if (status) FromProto(std::move(proto));
return status;
}
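A hedged sketch of the two conversion paths after this change: FromProto takes the proto by value, so a caller that no longer needs it can std::move it in (the body still copies field-by-field for now, per the TODO), while FromConstProto preserves the old copying behavior. The proto sources below are hypothetical:

    VariantTensorDataProto proto = ReadProtoSomehow();
    VariantTensorData consumed;
    consumed.FromProto(std::move(proto));   // proto may be left empty afterwards

    VariantTensorDataProto shared = ReadProtoSomehow();
    VariantTensorData copied;
    copied.FromConstProto(shared);          // shared is left untouched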
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 7500e77d43..8a240ee1e3 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -19,13 +19,13 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class VariantTensorDataProto;
-class Tensor;
// The serialization format for Variant objects. Objects with references to
// other Tensors can simply store those tensors in the `tensors` field, and
@@ -38,7 +38,7 @@ class Tensor;
class VariantTensorData {
public:
VariantTensorData();
- VariantTensorData(const VariantTensorDataProto& proto);
+ VariantTensorData(VariantTensorDataProto proto);
~VariantTensorData();
// Name of the type of objects being serialized.
@@ -68,12 +68,14 @@ class VariantTensorData {
// Conversion to and from VariantTensorDataProto
void ToProto(VariantTensorDataProto* proto) const;
- bool FromProto(const VariantTensorDataProto& proto);
+ // This allows optimizations via std::move.
+ bool FromProto(VariantTensorDataProto proto);
+ bool FromConstProto(const VariantTensorDataProto& proto);
// Serialization via VariantTensorDataProto
string SerializeAsString() const;
bool SerializeToString(string* buf);
- bool ParseFromString(const string& s);
+ bool ParseFromString(string s);
string DebugString() const;
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index eef5c47d15..08d09de7b8 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) {
struct TensorList {
void Encode(VariantTensorData* data) const { data->tensors_ = vec; }
- bool Decode(const VariantTensorData& data) {
- vec = data.tensors_;
+ bool Decode(VariantTensorData data) {
+ vec = std::move(data.tensors_);
return true;
}
@@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) {
x.Encode(&serialized);
Variant y = TensorList();
- y.Decode(serialized);
+ y.Decode(std::move(serialized));
const TensorList& decoded_vec = *y.get<TensorList>();
for (int i = 0; i < 4; ++i) {
@@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) {
EXPECT_EQ(y_unknown.DebugString(),
strings::StrCat(
"Variant<type: TensorList value: ", data.DebugString(), ">"));
-
- TensorList unknown_decoded_vec;
- EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
- }
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
- }
}
TEST(VariantTest, VariantArray) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index ee10194142..eeb5c14eaa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() {
}
if (processed < node_defs_.size()) {
- LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed)
+ LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
<< " NODES IN A CYCLE";
for (int64 i = 0; i < node_defs_.size(); i++) {
if (pending_count_[i] != 0) {
LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
- << "WITH PENDING COUNT = " << pending_count_[i];
+ << " WITH PENDING COUNT = " << pending_count_[i];
}
}
return errors::InvalidArgument(node_defs_.size() - processed,
@@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
const NodeDef* node_def = node_defs_[pair->second.gdef_index];
const OpDef* op_def;
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
- if (key.second >= op_def->output_arg_size()) {
+ int num_outputs;
+ TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+ if (key.second >= num_outputs) {
// key's index out of bounds
missing_unused_input_map_keys_->push_back(key);
}
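The old check compared against op_def->output_arg_size(), which counts declared output arguments rather than instantiated outputs, so variadic output lists were undercounted. A hedged sketch of the distinction, assuming the TestVariadicOutput op registered in the test change below and a hypothetical node_def with attr N = 5:

    const OpDef* op_def;
    TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
    int num_outputs;
    TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, *op_def, &num_outputs));
    // op_def->output_arg_size() == 1 (the single "outputs: N * int32" arg),
    // while num_outputs == 5, so an input_map key such as ("variadic", 4)
    // is in range and must not be reported as missing.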
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 73142ebde7..3eef6bd2bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput")
.Output("y: T")
.Attr("T: {float, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+ .Output("outputs: N * int32")
+ .Attr("N: int >= 0")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
@@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
// Unused but not missing
opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
- node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ node { name: 'variadic' op: 'TestVariadicOutput'
+ attr { key: "N" value { i: 5 } } }
)EOF",
opts, &refiner, &results);
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2e644fe987..f5b0105862 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ea7788f654..0a38aa1c91 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) {
return ret;
}
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+ .Input(in)
+ .Attr("message", message)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+ .Attr("T", type)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+ .Input(in)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 8585b35a19..bd0284d43a 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type);
// Add a DiagPart node in "g".
Node* DiagPart(Graph* g, Node* in, DataType type);
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow
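A hedged sketch of how the new test helpers compose in a test body; the graph built here is illustrative only:

    Graph g(OpRegistry::Global());
    Node* arg = test::graph::Arg(&g, /*index=*/0, DT_FLOAT);
    Node* checked = test::graph::CheckNumerics(&g, arg, "found NaN or Inf");
    test::graph::Retval(&g, /*index=*/0, checked);
    GraphDef gdef;
    test::graph::ToGraphDef(&g, &gdef);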
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 6710ff9df3..56c8339d57 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
}
bool IsEnqueue(const NodeDef& n) {
- return (n.op().find("Enqueue") != std::string::npos &&
- n.op().find("EnqueueMany") == std::string::npos);
+ return (n.op().find("Enqueue") != string::npos &&
+ n.op().find("EnqueueMany") == string::npos);
}
bool IsDequeue(const NodeDef& n) {
- return (n.op().find("Dequeue") != std::string::npos &&
- n.op().find("DequeueMany") == std::string::npos);
+ return (n.op().find("Dequeue") != string::npos &&
+ n.op().find("DequeueMany") == string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
@@ -345,6 +345,56 @@ void VerboseLogUnknownDimensionSources(
}
}
+bool IsShapeFullyDefinedIntegerVectorOrScalar(
+ InferenceContext* ic, const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape, const DataType& dtype) {
+ if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
+ !ic->FullyDefined(tensor_as_shape) ||
+ (dtype != DT_INT32 && dtype != DT_INT64)) {
+ return false;
+ }
+ return true;
+}
+
+// Returned tensor's shape is like `shape`, and its values and dtype are from
+// `tensor_as_shape` and `dtype`.
+TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(dtype);
+ auto* shape_proto = tensor_proto.mutable_tensor_shape();
+ if (ic->Rank(shape) == 1) {
+ shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
+ }
+ // For a scalar tensor, tensor_shape field will be left empty; no dim.
+ for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
+ int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
+ if (dtype == DT_INT32) {
+ tensor_proto.add_int_val(value);
+ } else {
+ tensor_proto.add_int64_val(value);
+ }
+ }
+ return tensor_proto;
+}
+
+// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
+// and dtype = `dtype`.
+NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ NodeDef const_node;
+ const_node.set_name("const_from_shape");
+ const_node.set_op("Const");
+ auto* attr = const_node.mutable_attr();
+ (*attr)["dtype"].set_type(dtype);
+ auto* tensor = (*attr)["value"].mutable_tensor();
+ *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype);
+ return const_node;
+}
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -429,18 +479,22 @@ class SymbolicShapeRefiner {
// perform shape inference on the function body.
//
// Propagate shape information of final function body node
- // to function node `node`.
+ // to function node `function_node`.
//
- // In the event of an error, UpdateNode will simply set `node`'s
+ // In the event of an error, UpdateNode will simply set `function_node`'s
// output shape to be Unknown.
- Status UpdateFunction(const NodeDef* node) {
- auto it = fun_to_grappler_function_item_.find(node->op());
+ Status UpdateFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it == fun_to_grappler_function_item_.end()) {
return errors::InvalidArgument(
- node->op(), " was not previously added to SymbolicShapeRefiner.");
+ function_node->op(),
+ " was not previously added to SymbolicShapeRefiner.");
}
- GrapplerFunctionItem& grappler_function_item = it->second;
+ // Copy (not reference) so that changes we make here (e.g., replacing
+ // Placeholder with Const) don't affect the one in
+ // fun_to_grappler_function_item_.
+ GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
@@ -453,7 +507,7 @@ class SymbolicShapeRefiner {
"supported.");
}
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
- const string& input = node->input(i);
+ const string& input = function_node->input(i);
const string& node_name = NodeName(input);
if (IsControlInput(input)) {
@@ -478,17 +532,48 @@ class SymbolicShapeRefiner {
TensorShapeProto proto;
const auto& handle = input_inference_context->output(output_port_num);
input_inference_context->ShapeHandleToProto(handle, &proto);
+ // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
+ for (int i = 0; i < proto.dim_size(); i++) {
+ if (proto.dim(i).size() < -1) {
+ proto.mutable_dim(i)->set_size(-1);
+ }
+ }
*attr_output_shape.mutable_shape() = proto;
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
+ // Replace input Placeholders with Consts, if values are known. Note that
+ // we don't check exceptions here as it's done in the above loop.
+ auto* ctx = GetNodeContext(function_node);
+ auto* ic = ctx->inference_context.get();
+ for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
+ const string& input = function_node->input(i);
+ const string& node_name = NodeName(input);
+ NodeDef* input_node = graph_.GetNode(node_name);
+ if (IsConstant(*input_node)) {
+ TF_CHECK_OK(
+ ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ // We have fully defined input_tensors_as_shapes for this input; use it
+ // as a const input to the function node.
+ NodeDef const_input_node = MakeConstNodeDefFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
+ TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
+ &grappler_function_item));
+ }
+ }
+
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(node);
int output = 0;
+ ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
// TODO(jmdecker): Handle case of multiple output tensors
@@ -505,8 +590,9 @@ class SymbolicShapeRefiner {
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
- return errors::FailedPrecondition("Unable to find return node ",
- node_name, " for ", node->name());
+ return errors::FailedPrecondition(
+ "Unable to find return function_node ", node_name, " for ",
+ function_node->name());
}
auto output_properties = gp.GetOutputProperties(retnode->name());
@@ -520,6 +606,14 @@ class SymbolicShapeRefiner {
ShapeHandle out;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
ic->set_output(output, out);
+ if (outprop.has_value()) {
+ // Forward tensor value to output_tensors_as_shape.
+ Tensor tensor;
+ if (tensor.FromProto(outprop.value())) {
+ MaybeSetTensorValueToShape(ic, tensor,
+ &ctx->output_tensors_as_shapes[output]);
+ }
+ }
output++;
}
@@ -562,21 +656,9 @@ class SymbolicShapeRefiner {
if (const_values[dst_input].FromProto(
input->attr().at("value").tensor())) {
input_tensors[dst_input] = &const_values[dst_input];
- // Integer tensors of rank one can also be interpreted as a shape
- // provided all their values are >= -1.
- if (const_values[dst_input].dims() == 1 &&
- (const_values[dst_input].dtype() == DT_INT32 ||
- const_values[dst_input].dtype() == DT_INT64)) {
- ShapeHandle tensor_shape = inference_context->Vector(
- const_values[dst_input].NumElements());
- ShapeHandle shp;
- if (inference_context
- ->MakeShapeFromTensor(input_tensors[dst_input],
- tensor_shape, &shp)
- .ok()) {
- input_tensors_as_shapes[dst_input] = shp;
- }
- }
+ MaybeSetTensorValueToShape(inference_context,
+ const_values[dst_input],
+ &input_tensors_as_shapes[dst_input]);
}
} else if (IsRank(*input)) {
if (c->inference_context->RankKnown(c->inference_context->input(0))) {
@@ -671,11 +753,13 @@ class SymbolicShapeRefiner {
// true, as the updates to the call node will have changed, even if it's
// the same function being called twice with the same input shapes.
// Example: simple_function.pbtxt
- if (UpdateFunction(node).ok()) {
+ auto s = UpdateFunction(node);
+ if (s.ok()) {
return Status::OK();
} else {
VLOG(1) << "UpdateFunction failed for " << node->op()
- << ". Defaulting to ShapeUnknown.";
+ << ". Defaulting to ShapeUnknown.\n"
+ << s.ToString();
}
}
@@ -942,13 +1026,25 @@ class SymbolicShapeRefiner {
: t->scalar<int64>()();
dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
} else {
- dims.push_back(ic->UnknownDim());
+ // Don't have tensor value, but use input_tensors_as_shapes, if
+ // possible.
+ const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
+ if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
+ ic->ValueKnown(ic->Dim(shape_handle, 0))) {
+ dims.push_back(ic->Dim(shape_handle, 0));
+ } else {
+ dims.push_back(ic->UnknownDim());
+ }
}
}
if (valid) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
}
+ } else if (IsIdentity(node)) {
+ // Pass input_tensors_as_shapes to output_tensors_as_shapes.
+ c->output_tensors_as_shapes.resize(1);
+ c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
} else if (IsSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
@@ -1053,6 +1149,46 @@ class SymbolicShapeRefiner {
}
private:
+ bool IsIntegerVector(const Tensor& tensor) {
+ if (tensor.dims() == 1 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
+ return true;
+ }
+ return false;
+ }
+
+ bool IsIntegerScalar(const Tensor& tensor) {
+ if (tensor.dims() == 0 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
+ tensor.NumElements() == 1) {
+ return true;
+ }
+ return false;
+ }
+
+ void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
+ ShapeHandle* tensors_as_shapes) {
+ // Integer tensors of rank one can also be interpreted as a shape
+ // provided all their values are >= -1.
+ if (IsIntegerVector(tensor)) {
+ ShapeHandle tensor_shape = ic->Vector(tensor.NumElements());
+ ShapeHandle shp;
+ // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1).
+ if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) {
+ *tensors_as_shapes = shp;
+ }
+ } else if (IsIntegerScalar(tensor)) {
+ // Scalar constant.
+ int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
+ : tensor.flat<int64>()(0);
+ // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
+ // It's a limitation as we use ShapeHandle as a means to pass values.
+ if (value >= -1) {
+ *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
+ }
+ }
+ }
+
const GraphView& graph_;
int graph_def_version_;
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
@@ -1528,6 +1664,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
continue;
}
+ auto* ic = ctx->inference_context.get();
+
// Fill input properties.
{
auto& input_properties = input_properties_[node.name()];
@@ -1535,19 +1673,26 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(input_properties.size(), 0);
- input_properties.resize(ctx->inference_context->num_inputs());
+ input_properties.resize(ic->num_inputs());
GraphView::InputPort input(&node, -1);
- for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->input(i),
- ctx->input_types[i],
+ for (int i = 0; i < ic->num_inputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i],
&input_properties[i]);
input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
- if (!IsConstant(*fanin.node)) {
- continue;
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to input_properties.value.
+ if (IsConstant(*fanin.node)) {
+ const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
+ *input_properties[i].mutable_value() = raw_val;
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
}
- const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
- *input_properties[i].mutable_value() = raw_val;
}
}
@@ -1558,11 +1703,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(output_properties.size(), 0);
- output_properties.resize(ctx->inference_context->num_outputs());
- for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->output(i),
- ctx->output_types[i],
+ output_properties.resize(ic->num_outputs());
+ for (int i = 0; i < ic->num_outputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]);
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to output_properties.value.
+ if (IsConstant(node)) {
+ const TensorProto& raw_val = node.attr().at("value").tensor();
+ *output_properties[i].mutable_value() = raw_val;
+ } else if (ctx->output_tensors_as_shapes.size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i])) {
+ *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i]);
+ }
}
}
}
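
The value-to-shape rule added in MaybeSetTensorValueToShape above can be illustrated outside TensorFlow: a rank-1 integer tensor such as {5, 7} is reinterpreted as the shape [5, 7], an integer scalar such as 42 becomes the one-dimensional shape [42], and any entry below -1 disqualifies the tensor. The sketch below is a plain C++ rendering of that rule; ValuesAsShape and the STL types are illustrative stand-ins, not the real Tensor/ShapeHandle machinery.

    // Standalone illustration, not TensorFlow code.
    #include <cstdint>
    #include <optional>
    #include <vector>

    // Returns the dimensions encoded by an integer tensor's values, or nullopt
    // if any value is below -1 (-1 itself stands for an unknown dimension).
    std::optional<std::vector<int64_t>> ValuesAsShape(
        const std::vector<int64_t>& values, bool is_scalar) {
      std::vector<int64_t> dims =
          is_scalar ? std::vector<int64_t>{values.at(0)} : values;
      for (int64_t d : dims) {
        if (d < -1) return std::nullopt;
      }
      return dims;  // e.g. {5, 7} -> shape [5, 7]; scalar 42 -> shape [42]
    }
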
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 8938b7c32e..362092a6cf 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -44,6 +44,30 @@ class GraphPropertiesTest : public ::testing::Test {
// Provision a single machine with 3 cpu cores
cluster_.reset(new SingleMachine(5 * 60, 3, 0));
TF_CHECK_OK(cluster_->Provision());
+
+    // This function is simply
+    // out = Fill(shape, value), but
+    // Fill requires the values of the shape input, not just its shape, to
+    // infer the output shape.
+ auto f = FunctionDefHelper::Create(
+ // Name
+ "MyFillFunc",
+ // Inputs
+ {"shape: int32", "value: float"},
+ // Outputs
+ {"out: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"},
+ "Fill",
+ {"shape", "value"},
+ {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+ },
+ // Returns
+ {{"out", "a:output:0"}});
+ function_lib_.add_function()->Swap(&f);
}
void TearDown() override {
@@ -69,7 +93,29 @@ class GraphPropertiesTest : public ::testing::Test {
return s;
}
+  // Compares the values of an integer (DT_INT32 or DT_INT64) tensor against
+  // the expected ones.
+ void ExpectTensorValues(const std::vector<int64>& expected,
+ const TensorProto& tensor_proto_to_compare) {
+ Tensor tensor;
+ EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
+ EXPECT_EQ(expected.size(), tensor.NumElements());
+    // We're only interested in integer tensors, as only shapes are exported
+    // as graph properties values.
+ CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
+ if (tensor.dtype() == DT_INT32) {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
+ }
+ } else {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
+ }
+ }
+ }
+
std::unique_ptr<SingleMachine> cluster_;
+ FunctionDefLibrary function_lib_;
};
TEST_F(GraphPropertiesTest, StaticProperties) {
@@ -785,7 +831,220 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
-TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) {
+TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
+ Output a1 = ops::Identity(s.WithOpName("a1"), a);
+ Output b = ops::Const(s.WithOpName("b"), 99, {});
+ Output b1 = ops::Identity(s.WithOpName("b1"), b);
+ Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
+ Output c1 = ops::Identity(s.WithOpName("c1"), c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ // Check output shapes.
+ EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
+ EXPECT_EQ("int32: [2]",
+ PropToString(properties.GetOutputProperties("a1")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c1")[0]));
+
+ // Check has_value.
+ EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
+  // Note that we propagate tensor values only for 1D vectors and scalars.
+ EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value());
+
+ // Check values.
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
+ ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
+ std::vector<int64> c_values;
+ for (int i = 0; i < 4 * 4 * 4; i++) {
+ c_values.push_back(1);
+ }
+ ExpectTensorValues({c_values},
+ properties.GetOutputProperties("c")[0].value());
+ ExpectTensorValues({c_values},
+ properties.GetInputProperties("c1")[0].value());
+  // No output value for c1, as it's neither a 1D vector nor a scalar.
+}
+
+TEST_F(GraphPropertiesTest, IdentityPassingShape) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 5, {2});
+ Output b = ops::Identity(s.WithOpName("b"), a);
+ Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
+  // Fill needs not only b's shape but also its value to figure out the output
+  // shape; hence, the Identity op (b) should pass a's value along as
+  // output_tensors_as_shapes.
+ Output d = ops::Fill(s.WithOpName("fill"), b, c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
+}
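
The Fill-based tests here all hinge on the same point: Fill(dims, value) produces a tensor whose shape is given by the values of dims, so shape inference needs those values and not merely the shape of dims. A minimal plain-C++ sketch of the 2-D case follows; Fill2D is an illustrative helper, not the TensorFlow op.

    // Standalone illustration, not TensorFlow code.
    #include <vector>

    // Mirrors what Fill({rows, cols}, value) computes: the output shape comes
    // from the *values* of the first argument, e.g. Fill2D({5, 5}, 0.1f) is a
    // 5x5 grid of 0.1.
    std::vector<std::vector<float>> Fill2D(const std::vector<int>& dims,
                                           float value) {
      return std::vector<std::vector<float>>(
          dims.at(0), std::vector<float>(dims.at(1), value));
    }
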
+
+TEST_F(GraphPropertiesTest, PackWithConstInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {});
+ Output b = ops::Const(s.WithOpName("b"), 2, {});
+ Output c = ops::Const(s.WithOpName("c"), 3, {});
+ Output d = ops::Const(s.WithOpName("d"), 4, {});
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
+  // ops of Consts.
+  // If output_tensors_as_shapes were not set for those Identity ops, or if the
+  // Pack op did not consume input_tensors_as_shapes, Fill's shape input would
+  // have no value; hence, its output shape would become unknown.
+ Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
+ Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
+ Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
+ Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
+ Output a = ops::Identity(s.WithOpName("a"), a0);
+ Output b = ops::Identity(s.WithOpName("b"), b0);
+ Output c = ops::Identity(s.WithOpName("c"), c0);
+ Output d = ops::Identity(s.WithOpName("d"), d0);
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+ Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
+  // Same as FunctionWithConstInput, but the function inputs are Identity of
+  // Const, so the shape values must reach the function through
+  // input_tensors_as_shapes rather than directly from a Const input.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+ Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
+ Output shape = ops::Identity(s.WithOpName("shape"), shape_);
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: int32"}, // Inputs
+ {"out: int32"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+
+  // MyFunc takes a Const (shape) and passes it through Identity. Expect the
+  // function output to have the same shape and value (via
+  // output_tensors_as_shapes) as the input Const tensor.
+ Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(true));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("int32: [2]", PropToString(out_prop0));
+ EXPECT_TRUE(out_prop0.has_value());
+ ExpectTensorValues({5, 7}, out_prop0.value());
+ ExpectTensorValues({5, 7},
+ properties.GetInputProperties("MyFunc")[0].value());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
// Create graph with a function that takes a scalar value so that we use
// Placeholder with scalar as for input to the function shape inference.
// Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
@@ -818,7 +1077,7 @@ TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) {
// MyFunc output shouldn't be unknown rank.
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically(false));
+ TF_CHECK_OK(properties.InferStatically(true));
const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
@@ -856,18 +1115,10 @@ TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
@@ -882,51 +1133,25 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
EXPECT_EQ(2, out_props.size());
const OpInfo::TensorProperties& out_prop0 = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
- EXPECT_EQ(4, out_prop0.shape().dim_size());
- EXPECT_EQ(128, out_prop0.shape().dim(0).size());
- EXPECT_EQ(112, out_prop0.shape().dim(1).size());
- EXPECT_EQ(112, out_prop0.shape().dim(2).size());
- EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));
const OpInfo::TensorProperties& out_prop1 = out_props[1];
- EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
- EXPECT_EQ(128, out_prop1.shape().dim(0).size());
- EXPECT_EQ(112, out_prop1.shape().dim(1).size());
- EXPECT_EQ(112, out_prop1.shape().dim(2).size());
- EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));
const auto in_props = properties.GetInputProperties("y0");
EXPECT_EQ(4, in_props.size());
const OpInfo::TensorProperties& in_prop0 = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
- EXPECT_EQ(1, in_prop0.shape().dim_size());
- EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+ EXPECT_EQ("float: [64]", PropToString(in_prop0));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_EQ(4, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(1, in_prop1.shape().dim(1).size());
- EXPECT_EQ(24, in_prop1.shape().dim(2).size());
- EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));
const OpInfo::TensorProperties& in_prop2 = in_props[2];
- EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
- EXPECT_EQ(4, in_prop2.shape().dim_size());
- EXPECT_EQ(128, in_prop2.shape().dim(0).size());
- EXPECT_EQ(224, in_prop2.shape().dim(1).size());
- EXPECT_EQ(224, in_prop2.shape().dim(2).size());
- EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+ EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));
const OpInfo::TensorProperties& in_prop3 = in_props[3];
- EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
- EXPECT_EQ(4, in_prop3.shape().dim_size());
- EXPECT_EQ(7, in_prop3.shape().dim(0).size());
- EXPECT_EQ(7, in_prop3.shape().dim(1).size());
- EXPECT_EQ(3, in_prop3.shape().dim(2).size());
- EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+ EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
@@ -986,18 +1211,10 @@ TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
@@ -1022,27 +1239,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
@@ -1066,28 +1272,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
@@ -1115,28 +1309,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(3, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, SymbolicShapes) {
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 7691f25327..5415324b48 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -127,7 +127,7 @@ static void ExtractExtraProperties(
// For filename input, the file size can also be useful.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("filename") != std::string::npos) {
+ op_def->input_arg(i).name().find("filename") != string::npos) {
Tensor tensor;
if (!tensor.FromProto(t)) {
continue;
@@ -153,7 +153,7 @@ static void ExtractExtraProperties(
// When the input is a handle (e.g. look up table handle), the information
// in the op itself is not sufficient to predict the op memory.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("handle") != std::string::npos) {
+ op_def->input_arg(i).name().find("handle") != string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
@@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
buckets_.begin(), std::plus<uint64>());
}
-std::string TensorSizeHistogram::ToString() const {
- std::string r;
+string TensorSizeHistogram::ToString() const {
+ string r;
char buf[200];
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
r.append(buf);
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index d2c7c67666..5fd6717712 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -80,7 +80,7 @@ class TensorSizeHistogram {
uint64 Max() const { return max_; }
uint64 NumElem() const { return num_elem_; }
uint64 SumElem() const { return sum_elem_; }
- std::string ToString() const;
+ string ToString() const;
protected:
const int Index(const uint64 value) const;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 02a379fca8..80889afc86 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
// Helper lambda to extract port num from _Send and _Recv op name.
auto get_port_num = [](const string& name) -> int {
- if (name.find("bn_0") != std::string::npos) {
+ if (name.find("bn_0") != string::npos) {
return 0;
- } else if (name.find("bn_1") != std::string::npos) {
+ } else if (name.find("bn_1") != string::npos) {
return 1;
- } else if (name.find("bn_2") != std::string::npos) {
+ } else if (name.find("bn_2") != string::npos) {
return 2;
- } else if (name.find("bn_minus1") != std::string::npos) {
+ } else if (name.find("bn_minus1") != string::npos) {
return -1;
}
return -999;
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
index 26d38a4931..97626346c7 100644
--- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
@@ -138,7 +138,7 @@ class GraphAnalyzer {
// The entries are owned by collation_map_, so must be removed from
// ordered_collation_ before removing them from collation_map_.
struct ReverseLessByCount {
- bool operator()(CollationEntry* left, CollationEntry* right) {
+ bool operator()(CollationEntry* left, CollationEntry* right) const {
return left->count > right->count; // Reverse order.
}
};
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
index 5029dff877..def9198a69 100644
--- a/tensorflow/core/grappler/inputs/utils.cc
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/inputs/utils.h"
-#include "tensorflow/core/platform/env.h"
#include <vector>
+#include "tensorflow/core/platform/env.h"
+
namespace tensorflow {
namespace grappler {
@@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) {
return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
}
-bool FileExists(const std::string& file, Status* status) {
+bool FileExists(const string& file, Status* status) {
*status = Env::Default()->FileExists(file);
return status->ok();
}
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result) {
Status status;
if (FileExists(graph_def_pbtxt_path, &status)) {
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
index 627dd5359f..4b9cb0a9ad 100644
--- a/tensorflow/core/grappler/inputs/utils.h
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files,
std::vector<Status>* status = nullptr);
bool FilesExist(const std::set<string>& files);
-bool FileExists(const std::string& file, Status* status);
+bool FileExists(const string& file, Status* status);
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..3521669b63 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
-bool IsElementWiseMonotonic(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_monotonic_ops =
+// Returns true if node represents a unary elementwise function that is
+// monotonic. If *is_non_decreasing is set to true, the function is
+// non-decreasing, e.g. sqrt, exp; if it is set to false, the function is
+// non-increasing, e.g. inv.
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
+ static const std::unordered_set<string>* monotonic_non_decreasing_ops =
CHECK_NOTNULL((new std::unordered_set<string>{
- "Relu",
- "Relu6",
- "Sigmoid",
- "Sqrt",
- "Tanh",
+ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
+ "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
+ "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
+ }));
+ static const std::unordered_set<string>* monotonic_non_increasing_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Inv",
+ "Reciprocal",
+ "Erfc",
+ "Rsqrt",
+ "Neg",
}));
- return element_wise_monotonic_ops->count(node.op()) > 0;
+ if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = true;
+ }
+ return true;
+ } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = false;
+ }
+ return true;
+ }
+ return false;
}
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
@@ -470,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
}
}
// Queue ops modify the queue which is a side effect.
- if (node.op().find("Queue") != std::string::npos) {
+ if (node.op().find("Queue") != string::npos) {
return false;
}
return !ModifiesInputsInPlace(node);
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..25ab6b65ac 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
-bool IsElementWiseMonotonic(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 70ad9f9a9b..029205248b 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -110,12 +110,13 @@ cc_library(
],
)
-tf_cuda_cc_test(
+tf_cc_test(
name = "constant_folding_test",
srcs = ["constant_folding_test.cc"],
- tags = ["requires-gpu-sm35"],
+ shard_count = 5,
deps = [
":constant_folding",
+ ":dependency_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:all_kernels",
@@ -514,6 +515,7 @@ cc_library(
":custom_graph_optimizer_registry",
":debug_stripper",
":dependency_optimizer",
+ ":experimental_implementation_selector",
":function_optimizer",
":graph_optimizer",
":layout_optimizer",
@@ -845,3 +847,68 @@ tf_cc_test(
"//third_party/eigen3",
],
)
+
+cc_library(
+ name = "function_api_info",
+ srcs = ["function_api_info.cc"],
+ hdrs = ["function_api_info.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "function_api_info_test",
+ size = "small",
+ srcs = ["function_api_info_test.cc"],
+ deps = [
+ ":function_api_info",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "experimental_implementation_selector",
+ srcs = ["experimental_implementation_selector.cc"],
+ hdrs = ["experimental_implementation_selector.h"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ ],
+)
+
+tf_cc_test(
+ name = "experimental_implementation_selector_test",
+ size = "small",
+ srcs = ["experimental_implementation_selector_test.cc"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":experimental_implementation_selector",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fed88d536..992e85d2c6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1121,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
- if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
- }
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -1328,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const string node_name = node->name();
NodeDef* x;
NodeDef* y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
bool updated = false;
- if (IsAdd(*node)) {
- if (IsNeg(*x)) {
- // (-a) + b = b - a
- node->set_op("Sub");
- node->mutable_input()->SwapElements(0, 1);
- node->set_input(1, x->input(0));
- node->add_input(AsControlDependency(x->name()));
- ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
- updated = true;
- } else if (IsNeg(*y)) {
- // a + (-b) = a - b
- node->set_op("Sub");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
- } else if (IsSub(*node)) {
- if (IsNeg(*y)) {
- // a - (-b) = a + b
- node->set_op("Add");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
+ if (IsNeg(*y)) {
+ // a - (-b) = a + b or a + (-b) = a - b
+ ForwardControlDependencies(node, {y});
+ ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
+ node->set_op(IsAdd(*node) ? "Sub" : "Add");
+ node->set_input(1, y->input(0));
+ updated = true;
+ } else if (IsAdd(*node) && IsNeg(*x)) {
+ // (-a) + b = b - a
+ ForwardControlDependencies(node, {x});
+ ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
+ node->set_op("Sub");
+ node->mutable_input()->SwapElements(0, 1);
+ node->set_input(1, x->input(0));
+ updated = true;
}
if (updated) {
AddToOptimizationQueue(node);
@@ -2706,8 +2691,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
// 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
+ bool is_non_decreasing = false;
if (!IsInPreserveSet(*inner_function) &&
- IsElementWiseMonotonic(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
@@ -2719,7 +2705,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
UpdateConsumers(reduction_node, inner_function->name());
ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
reduction_node->name());
-
+ if (!is_non_decreasing) {
+ // Flip Min<->Max if the function is non-increasing, e.g.
+ // Max(Neg(x)) = Neg(Min(x)).
+ const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+ reduction_node->set_op(opposite);
+ }
AddToOptimizationQueue(reduction_node);
AddToOptimizationQueue(inner_function);
AddToOptimizationQueue(inner_input);
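
The Min<->Max flip added above rests on a simple identity: for a non-decreasing f, max(f(x)) = f(max(x)), while for a non-increasing f such as Neg, max(f(x)) = f(min(x)). The standalone check below verifies that identity for Neg in plain C++; it is an illustration of the rewrite, not optimizer code.

    // Standalone illustration, not TensorFlow code.
    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      const std::vector<double> x = {1.0, 2.0, 3.0};
      // Max over Neg(x) ...
      double max_of_neg = -x[0];
      for (double v : x) max_of_neg = std::max(max_of_neg, -v);
      // ... equals Neg of the min over x, which is the form the optimizer
      // emits after flipping the reduction from Max to Min.
      const double min_x = *std::min_element(x.begin(), x.end());
      assert(max_of_neg == -min_x);
      return 0;
    }
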
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 551c3652bf..d457eb6d21 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -61,7 +61,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool fold_multiply_into_conv = true;
bool fold_transpose_into_matmul = true;
bool hoist_common_factor_out_of_aggregation = true;
- bool hoist_cwise_unary_chains = false;
+ bool hoist_cwise_unary_chains = true;
bool minimize_broadcasts = true;
bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index bfccc0affd..88839d944c 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
- auto add_all = ops::AddN(s.WithOpName("add_all"),
- {add_x_y, add_negx_y, add_x_negy, add_negx_negy,
- sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy});
+ Output neg_x_with_dep = ops::Neg(
+ s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
+ Output add_negx_with_dep_y =
+ ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
+ auto add_all =
+ ops::AddN(s.WithOpName("add_all"),
+ {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
+ sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
GrapplerItem item;
item.fetch = {"add_all"};
@@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveNegation(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &output);
+ OptimizeTwice(&optimizer, &item, &output);
EXPECT_EQ(item.graph.node_size(), output.node_size());
int found = 0;
@@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
if (node.name() == "Add_negx_y") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
} else if (node.name() == "Add_x_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Add_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Neg_y", node.input(0));
- EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("Neg_x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Sub_x_negy") {
++found;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Sub_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
- EXPECT_EQ("^Neg_x", node.input(3));
+ } else if (node.name() == "Add_negx_with_dep_y") {
+ ++found;
+ EXPECT_EQ("Sub", node.op());
+ EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ EXPECT_EQ("^Add_x_y", node.input(2));
}
}
- EXPECT_EQ(5, found);
+ EXPECT_EQ(6, found);
auto tensors = EvaluateNodes(output, item.fetch, feed);
EXPECT_EQ(1, tensors.size());
@@ -3248,6 +3254,48 @@ TEST_F(ArithmeticOptimizerTest,
VerifyGraphsMatch(item.graph, output, __LINE__);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output neg = ops::Neg(s.WithOpName("neg"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check if the inputs are switched
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "neg") {
+ EXPECT_EQ("Neg", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Min", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 815bd23307..99737a71eb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -136,6 +136,27 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input;
}
+bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
+ int* axis) {
+ if (node->op() != "ConcatV2" ||
+ properties.GetInputProperties(node->name()).empty()) {
+ return false;
+ }
+ const auto& axis_input = properties.GetInputProperties(node->name()).back();
+ if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
+ return false;
+ }
+
+ Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
+ if (!axis_tensor.FromProto(axis_input.value())) {
+ return false;
+ }
+ *axis = axis_input.dtype() == DT_INT64
+ ? static_cast<int>(axis_tensor.scalar<int64>()())
+ : axis_tensor.scalar<int32>()();
+ return true;
+}
+
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@@ -852,19 +873,7 @@ DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
}
return dtype;
}
-bool IsValidConstShapeForNCHW(const TensorShapeProto& shape) {
- if (shape.dim_size() != 4) {
- return false;
- }
- int num_dim_larger_than_one = 0;
- for (const auto& dim : shape.dim()) {
- if (dim.size() > 1) ++num_dim_larger_than_one;
- }
- return num_dim_larger_than_one <= 1;
-}
-const string& GetShape(const NodeDef& node) {
- return node.attr().at("data_format").s();
-}
+
} // namespace
// static
@@ -1711,7 +1720,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
- if (MulConvPushDown(*properties, optimized_graph, node)) {
+ if (MulConvPushDown(node, *properties)) {
graph_modified_ = true;
return Status::OK();
}
@@ -1731,6 +1740,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
+ if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
return Status::OK();
}
@@ -2553,9 +2567,8 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
return false;
}
-bool ConstantFolding::MulConvPushDown(const GraphProperties& properties,
- GraphDef* optimized_graph,
- NodeDef* node) {
+bool ConstantFolding::MulConvPushDown(NodeDef* node,
+ const GraphProperties& properties) {
// Push down multiplication on ConvND.
// * ConvND
// / \ / \
@@ -2631,14 +2644,12 @@ bool ConstantFolding::MulConvPushDown(const GraphProperties& properties,
}
const auto& const_shape = const_props[0].shape();
- if (GetShape(*conv_node) == "NHWC") {
- TensorShapeProto new_filter_shape;
- if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) {
- return false;
- }
- if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
- return false;
- }
+ TensorShapeProto new_filter_shape;
+ if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) {
+ return false;
+ }
+ if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
+ return false;
}
string mul_new_name =
@@ -2672,69 +2683,6 @@ bool ConstantFolding::MulConvPushDown(const GraphProperties& properties,
}
node_map_->AddNode(mul_new_name, node);
- if (GetShape(*conv_node) == "NCHW") {
- if (const_node->attr().at("value").tensor().tensor_shape().dim_size() <=
- 1) {
- // Broadcast should work for scalar or 1D. No need to reshape.
- return true;
- }
- if (!IsValidConstShapeForNCHW(
- const_node->attr().at("value").tensor().tensor_shape())) {
- return false;
- }
- // Adds Const node for Reshape.
- auto* shape_const_node = optimized_graph->add_node();
- const string shape_const_node_name =
- OptimizedNodeName(*const_node, "_new_shape");
- shape_const_node->set_name(shape_const_node_name);
- shape_const_node->set_op("Const");
- shape_const_node->set_device(const_node->device());
- (*shape_const_node->mutable_attr())["dtype"].set_type(DT_INT32);
- Tensor t(DT_INT32, {4});
- t.flat<int32>()(0) = 1;
- t.flat<int32>()(1) = 1;
- t.flat<int32>()(2) = 1;
- t.flat<int32>()(3) = const_node->attr()
- .at("value")
- .tensor()
- .tensor_shape()
- .dim(1) // IsValidConstShapeForNCHW guarantees
- // dim 1 is the dim to reshape
- .size();
- t.AsProtoTensorContent(
- (*shape_const_node->mutable_attr())["value"].mutable_tensor());
- node_map_->AddNode(shape_const_node_name, shape_const_node);
-
- // Adds Reshape node.
- auto* reshape_node = optimized_graph->add_node();
- const string reshape_node_name =
- OptimizedNodeName(*const_node, "_reshape");
- reshape_node->set_op("Reshape");
- reshape_node->set_name(reshape_node_name);
- reshape_node->set_device(const_node->device());
- (*reshape_node->mutable_attr())["T"].set_type(
- const_node->attr().at("dtype").type());
- (*reshape_node->mutable_attr())["Tshape"].set_type(DT_INT32);
- node_map_->AddNode(reshape_node_name, reshape_node);
-
- // const_node -> reshape_node
- node_map_->RemoveOutput(const_node->name(), node->name());
- *reshape_node->add_input() = const_node->name();
- node_map_->AddOutput(const_node->name(), reshape_node_name);
-
- // shape_const_node -> reshape_node
- *reshape_node->add_input() = shape_const_node_name;
- node_map_->AddOutput(shape_const_node_name, reshape_node_name);
-
- // reshape_node -> node (Mul)
- node_map_->AddOutput(reshape_node_name, node->name());
- if (left_child_is_constant) {
- node->set_input(0, reshape_node_name);
- } else {
- node->set_input(1, reshape_node_name);
- }
- }
-
return true;
}
return false;
@@ -2988,6 +2936,55 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
return false;
}
+bool ConstantFolding::MergeConcat(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node) {
+ // We only optimize for ConcatV2.
+ int axis;
+ if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+ nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
+ node_map_->GetOutputs(node->name()).size() != 1) {
+ return false;
+ }
+
+ NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
+ int parent_axis;
+ if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+ return false;
+ }
+
+ const int index = NumNonControlInputs(*node) - 1;
+ auto inputs = parent->input();
+ parent->clear_input();
+ for (int i = 0; i < inputs.size(); ++i) {
+ if (IsSameInput(inputs.Get(i), node->name())) {
+ for (int j = 0; j < node->input_size(); ++j) {
+ if (j < index) {
+          // Non-axis input tensors: add them to the parent's input list.
+ parent->add_input(node->input(j));
+ node_map_->RemoveOutput(node->input(j), node->name());
+ node_map_->AddOutput(node->input(j), parent->name());
+ }
+        // Skip j == index, which is the axis tensor.
+ if (j > index) {
+          // Control dependencies: push them back onto inputs so they can be
+          // forwarded to the parent.
+ *inputs.Add() = node->input(j);
+ }
+ }
+ } else {
+ parent->add_input(inputs.Get(i));
+ }
+ }
+ node->clear_input();
+ node->set_op("NoOp");
+ node->clear_attr();
+ node_map_->RemoveNode(node->name());
+ (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+
+ return true;
+}
+
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {
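
MergeConcat above splices a ConcatV2 node's inputs into its single consumer when both concatenate along the same axis, e.g. Concat(Concat(a, b, axis), c, axis) becomes Concat(a, b, c, axis). The plain-C++ sketch below shows just that splice on lists of input names; SpliceConcatInputs is an illustrative helper that assumes the axis inputs have already been stripped, unlike the real NodeDef-based code.

    // Standalone illustration, not TensorFlow code.
    #include <string>
    #include <vector>

    // Replaces every occurrence of `child_name` in the parent's input list
    // with the child's own inputs, which is the core of the ConcatV2 merge.
    std::vector<std::string> SpliceConcatInputs(
        const std::vector<std::string>& parent_inputs,
        const std::vector<std::string>& child_inputs,
        const std::string& child_name) {
      std::vector<std::string> merged;
      for (const std::string& in : parent_inputs) {
        if (in == child_name) {
          merged.insert(merged.end(), child_inputs.begin(), child_inputs.end());
        } else {
          merged.push_back(in);
        }
      }
      return merged;
    }
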
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 051dfb681e..8593b3e0b8 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -125,8 +125,7 @@ class ConstantFolding : public GraphOptimizer {
// Aggregate constants present around a conv operator. Returns true if the
// transformation was applied successfully.
- bool MulConvPushDown(const GraphProperties& properties,
- GraphDef* optimized_graph, NodeDef* node);
+ bool MulConvPushDown(NodeDef* node, const GraphProperties& properties);
// Strength reduces floating point division by a constant Div(x, const) to
// multiplication by the reciprocal Mul(x, Reciprocal(const)).
@@ -210,6 +209,10 @@ class ConstantFolding : public GraphOptimizer {
// Removes Split or SplitV node if possible.
bool RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
+
+ bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 0683572dcc..2a19b3f95a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -240,7 +240,7 @@ TEST_F(ConstantFoldingTest, AddTree) {
}
}
-TEST_F(ConstantFoldingTest, ConvPushDownTestNHWC) {
+TEST_F(ConstantFoldingTest, ConvPushDownTest) {
// Tests if the following rewrite is performed:
//
// * Conv2D
@@ -2030,6 +2030,130 @@ TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
CompareGraphs(want, got);
}
+TEST_F(ConstantFoldingTest, MergeConcat) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
+ &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
+ Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis1", "Const", {}, {}, &want);
+ AddNode("axis2", "Const", {}, {}, &want);
+ AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
+ AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
@@ -3080,110 +3204,6 @@ TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
-#if GOOGLE_CUDA
-TEST_F(ConstantFoldingTest, ConvPushDownTestNCHW) {
- // Tests if the following rewrite is performed:
- //
- // * Conv2D
- // / \ / \
- // c Conv2D --> x (c * filter)
- // / \
- // x filter
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
- int input_channel = 1;
- int output_channel = 2;
- int filter_size = 1;
-
- TensorShape filter_shape(
- {filter_size, filter_size, input_channel, output_channel});
-
- // Filter shape: [1, 1, 1, 2]
- // Filter for output channel 0 = {2.f}
- // Filter for output channel 1 = {-2.f}
- // clang-format off
- Output filter =
- ops::Const(s.WithOpName("filter"), {
- {
- {{2.f, -2.f}}
- }
- });
- // clang-format on
-
- int batch_size = 1;
- int matrix_size = 3;
- // input shape: [1,1,3,3]
- TensorShape input_shape(
- {batch_size, input_channel, matrix_size, matrix_size});
- Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
- ops::Placeholder::Shape(input_shape));
-
- Output conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1},
- "VALID", ops::Conv2D::DataFormat("NCHW"));
- Output c = ops::Const(s.WithOpName("c"), 2.0f, /* shape */ {1, 2, 1, 1});
- Output mul = ops::Mul(s.WithOpName("mul"), c, conv);
-
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
- ConstantFolding fold(nullptr);
- GraphDef output;
- Status status = fold.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- // Here only op/IO are checked. The values are verified by EvaluateNodes
- // below.
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "mul") {
- ++found;
- EXPECT_EQ("Conv2D", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("x", node.input(0));
- EXPECT_EQ("conv/merged_input", node.input(1));
- } else if (node.name() == "conv/merged_input") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ(0, node.input_size());
- }
- }
- EXPECT_EQ(2, found);
-
- // Check that const folded multiplication node has the expected value.
- std::vector<string> fetch = {"mul"};
- // Input shape (NCHW) is [1,1,3,3], filter is [1,1,1,2] output shape should be
- // (NCHW) [1,2,3,3]
- ::tensorflow::Input::Initializer x{
- {
- {
- {1.f, 2.f, 3.f}, // H = 0
- {4.f, 5.f, 6.f}, // H = 1
- {7.f, 8.f, 9.f} // H = 2
- } // C = 0
- } // N = 0
- };
-
- // |1,2,3|
- // conv( |4,5,6|, // input
- // |7,8,9|
- // [[[2,-2]]]) // filter
- // * [1,2,1,1] // mul by const
- // =
- // [
- // |4, 8, 12|
- // |16,20,24| ==> output channel 0
- // |28,32,36|
- //
- // | -4, -8,-12|
- // |-16,-20,-24| ==> output channel 1
- // |-28,-32,-36|
- // ]
- auto actual = EvaluateNodes(output, fetch, {{"x", x.tensor}});
- auto expected = EvaluateNodes(item.graph, fetch, {{"x", x.tensor}});
- test::ExpectTensorEqual<float>(expected[0], actual[0]);
-}
-#endif
-
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 530c957068..e84df10778 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -19,7 +19,6 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -56,8 +55,8 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:functional_ops",
+ "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
@@ -107,7 +106,6 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core/kernels:cast_op",
],
)
@@ -164,7 +162,6 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work.
],
)
@@ -256,7 +253,6 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -275,6 +271,43 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/kernels:control_flow_ops",
+ ],
+)
+
+cc_library(
+ name = "map_parallelization",
+ srcs = ["map_parallelization.cc"],
+ hdrs = [
+ "map_parallelization.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_parallelization_test",
+ srcs = ["map_parallelization_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":map_parallelization",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
],
)
@@ -355,6 +388,7 @@ cc_library(
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
+ ":map_parallelization",
":map_vectorization",
":noop_elimination",
":shuffle_and_repeat_fusion",
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 5a7fe19265..d4ab444036 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -273,7 +273,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
- if (name.rfind("_generated") != std::string::npos &&
+ if (name.rfind("_generated") != string::npos &&
(name.rfind("_generated") == (name.size() - strlen("_generated")))) {
name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
} else {
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
new file mode 100644
index 0000000000..305325e434
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -0,0 +1,106 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+bool CanParallelize(const FunctionDef& function,
+ const FunctionLibraryDefinition& library) {
+ if (!function.signature().is_stateful()) return true;
+
+ for (const auto& node : function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+    // Assert is marked as stateful, but it does not carry any state (aside
+    // from performing I/O). As with CUDA, we do not guarantee that a failing
+    // assert is the first one encountered, and accepting that lets us
+    // parallelize the map function.
+ if (op_def->is_stateful() && op_def->name() != "Assert") return false;
+ }
+
+ return true;
+}
+
+NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
+ NodeDef parallel_map = map_node;
+ graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
+ &parallel_map);
+ parallel_map.set_op("ParallelMapDataset");
+ // TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
+  // so that dynamic tuning will pick the optimal value at runtime. Because
+ // this feature is not yet implemented, we set it to 2, which is the smallest
+ // value that introduces parallelism.
+ auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph);
+ parallel_map.add_input(num_parallel_calls->name());
+
+ return parallel_map;
+}
+
+} // namespace
+
+Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ auto* function =
+ function_library.Find(map_node->attr().at("f").func().name());
+ if (!CanParallelize(*function, function_library)) continue;
+
+ auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
+ graph.ReplaceInput(*map_node, *parallel_map);
+
+ // TODO(prazek): we could also remove map functions from library if they
+ // are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
new file mode 100644
index 0000000000..ac9cf7e12a
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization parallelizes MapDataset when the map function is stateless.
+class MapParallelization : public CustomGraphOptimizer {
+ public:
+ MapParallelization() = default;
+ ~MapParallelization() override = default;
+
+  string name() const override { return "map_parallelization"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
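For reference, a minimal sketch (not part of the patch) of requesting the optimizer registered above under the name "map_parallelization" through the custom-optimizer list; this mirrors the RewriterConfig usage in meta_optimizer_test.cc later in this diff, and the surrounding session / meta-optimizer plumbing is assumed rather than shown.

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    // Sketch only: ask the meta optimizer to run the registered
    // "map_parallelization" custom graph optimizer.
    tensorflow::RewriterConfig MakeConfigWithMapParallelization() {
      tensorflow::RewriterConfig rewriter_config;
      rewriter_config.add_custom_optimizers()->set_name("map_parallelization");
      return rewriter_config;
    }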
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
new file mode 100644
index 0000000000..b2a5d9b6af
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", {}},
+ {"output_types", {}}});
+}
+
+const char stateless_fun_name[] = "XTimesTwo";
+const char stateful_fun_name[] = "RandomUniform";
+
+TEST(MapParallelizationTest, ParallelizeSimpleMap) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateless_fun_name)},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+TEST(MapParallelization, ParallelizeAssert) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateful_fun_name),
+ MakeMapNode("map2", "map1", stateless_fun_name),
+ NDef("cache", "CacheDataset", {"map2", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ test::function::RandomUniform(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
new file mode 100644
index 0000000000..2c36c9b7b3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -0,0 +1,111 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status ExperimentalImplementationSelector::LoadFunctions(
+ const GraphDef& graph) {
+ lib_info_.reset(new FunctionLibraryApiInfo);
+ TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
+ NodeDef* node_def) const {
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, or
+ // 2. Via the @defun functional interface, where the real function name
+  //     appears as an attribute with type func.
+ std::vector<string> function_attribute_names;
+ for (const auto& attr : node_def->attr()) {
+ if (attr.second.has_func() &&
+ lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
+ function_attribute_names.emplace_back(attr.first);
+ }
+ }
+
+ if (function_attribute_names.empty() &&
+ lib_info_->GetApiInfo(node_def->op()) == nullptr) {
+ // A regular op, or a function which has no interface.
+ return Status::OK();
+ }
+
+ string task, device;
+ if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
+    return errors::Internal("Could not split device name: ", node_def->device());
+ }
+ VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
+ << " = (" << task << ", " << device << ")";
+ DeviceNameUtils::ParsedName parsed_name;
+ DeviceNameUtils::ParseLocalName(device, &parsed_name);
+
+ for (const auto& attr_name : function_attribute_names) {
+ string function_name = node_def->attr().at(attr_name).func().name();
+ string best_function_name;
+ lib_info_->GetBestImplementation(function_name, parsed_name.type,
+ &best_function_name);
+ if (function_name != best_function_name) {
+ node_def->mutable_attr()
+ ->find(attr_name)
+ ->second.mutable_func()
+ ->set_name(best_function_name);
+ }
+ }
+ if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
+ string best_function_name;
+ lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+ &best_function_name);
+ if (node_def->op() != best_function_name) {
+ node_def->set_op(best_function_name);
+ }
+ }
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::SelectImplementation(
+ GraphDef* graph) const {
+ for (int k = 0; k < graph->node_size(); ++k)
+ TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));
+
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
+ return SelectImplementation(optimized_graph);
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
new file mode 100644
index 0000000000..82f7473a14
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// -- EXPERIMENTAL --
+// This transformation replaces function calls by the appropriate function
+// definition based on properties of the runtime system. For instance,
+// we may choose one implementation over another if we have a GPU with
+// enough memory available.
+//
+// It is a way for the programmer to specify alternative implementations
+// of the same functionality in the graph, and let TensorFlow pick the
+// most appropriate one at runtime.
+//
+// For instance, the python code might specify:
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one',
+// experimental_api_preferred_device='GPU')
+// def plus_one_gpu(x): return x + 1.0
+//
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one')
+// def plus_one_reference_implementation(x): return x + 1.0
+// input = tf.constant(2.0, dtype=tf.float32)
+//
+// z = plus_one_reference_implementation(input)
+// z = plus_one_gpu(input)
+// print(sess.run(z))
+//
+// At runtime, we will trim either `plus_one_gpu` or
+// `plus_one_reference_implementation` based on the availability of the GPU.
+//
+// Available annotations:
+// - experimental_api_implements(string): all functions mapping to the same
+// string can be interchanged. For now, all functions must have the same
+// signature and overloads are not allowed. Defuns within defuns are
+// allowed.
+// - experimental_api_preferred_device(string): sets which device is preferred.
+class ExperimentalImplementationSelector : public CustomGraphOptimizer {
+ public:
+ ExperimentalImplementationSelector() = default;
+ ~ExperimentalImplementationSelector() override = default;
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+ string name() const override {
+ return "experimental_implementation_selector";
+ }
+
+ // This call is not thread-safe.
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ // Does not take any feedback.
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ Status LoadFunctions(const GraphDef& graph);
+ Status MaybeOptimizeFunctionCall(NodeDef* node_def) const;
+
+  // Finds all call sites for functions, then replaces each with the
+  // appropriate implementation.
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, and
+ // 2. Via the functional interface, where the function name appears as an
+ // Attr.
+ //
+ // There may be multiple call sites for a given function. The function body
+ // may call into another function, so a function might have to be duplicated.
+ // For simplicity, we do not change function bodies. Also, we do not change
+ // gradients.
+ Status SelectImplementation(GraphDef* graph) const;
+
+ std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector);
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
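To make the annotations described in the header comment concrete, the following is a minimal sketch (not part of the patch); the interface name "times_two" and the XTimesTwo/XAddX bodies are borrowed from the selector test below, and everything else is assumed.

    #include "tensorflow/core/framework/function.pb.h"
    #include "tensorflow/core/framework/function_testlib.h"

    // Sketch: build a FunctionDefLibrary whose two members implement the same
    // interface but prefer different devices.
    tensorflow::FunctionDefLibrary MakeAnnotatedLibrary() {
      tensorflow::FunctionDefLibrary library;

      tensorflow::FunctionDef cpu_def = tensorflow::test::function::XTimesTwo();
      (*cpu_def.mutable_attr())["experimental_api_implements"].set_s("times_two");
      (*cpu_def.mutable_attr())["experimental_api_preferred_device"].set_s("CPU");
      *library.add_function() = cpu_def;

      tensorflow::FunctionDef gpu_def = tensorflow::test::function::XAddX();
      (*gpu_def.mutable_attr())["experimental_api_implements"].set_s("times_two");
      (*gpu_def.mutable_attr())["experimental_api_preferred_device"].set_s("GPU");
      *library.add_function() = gpu_def;

      return library;
    }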
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
new file mode 100644
index 0000000000..3f1ebefac6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char CpuDevice[] = "/device:CPU:0";
+constexpr char GpuDevice[] = "/device:GPU:0";
+
+class ExperimentalImplementationSelectorTest : public GrapplerTest {};
+
+TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ std::unique_ptr<CustomGraphOptimizer> optimizer(
+ new ExperimentalImplementationSelector);
+ ASSERT_NE(nullptr, optimizer);
+ TF_ASSERT_OK(optimizer->Init());
+
+ GraphDef output;
+ const Status status = optimizer->Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // This is a trivial graph so there is nothing to update.
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("times_two");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XAddX();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("times_two");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice),
+ NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(output.node_size(), 5);
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "y1") {
+ // Make sure the implementation has been swapped to use the GPU version.
+ EXPECT_EQ("XAddX", node.op());
+ } else if (node.name() == "y2") {
+ // Make sure the implementation is not changed.
+ EXPECT_EQ("XTimesTwo", node.op());
+ }
+ }
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XTimesFour();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice),
+ NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ const Tensor input = test::AsScalar<float>(1.0f);
+ item.fetch = {"z"};
+ item.feed.emplace_back("x", input);
+
+ const auto four_times_boosted_tensor = EvaluateFetchNodes(item);
+ test::ExpectTensorEqual<float>(four_times_boosted_tensor[0],
+ test::AsScalar<float>(4.0f));
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ GrapplerItem optimized(item, std::move(output));
+ const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
+ test::AsScalar<float>(2.0f));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc
new file mode 100644
index 0000000000..798e0f6fd5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+FunctionApiInfo::FunctionApiInfo() {}
+FunctionApiInfo::~FunctionApiInfo() {}
+
+Status FunctionApiInfo::Init(const FunctionDef& function_def) {
+ for (const auto& attr : function_def.attr()) {
+ if (attr.first == "experimental_api_preferred_device") {
+ preferred_device_ = attr.second.s();
+ }
+ if (attr.first == "experimental_api_implements") {
+ interface_name_ = attr.second.s();
+ }
+ }
+ if (interface_name_.empty() && !preferred_device_.empty()) {
+ return errors::InvalidArgument(
+ "Function '", function_def.signature().name(),
+ "' has a preferred device, but does not implement an interface");
+ }
+ return Status::OK();
+}
+
+const string& FunctionApiInfo::preferred_device() const {
+ return preferred_device_;
+}
+
+const string& FunctionApiInfo::interface_name() const {
+ return interface_name_;
+}
+
+FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
+FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
+
+namespace {
+bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) {
+ if (f1.ret().size() != f2.ret().size()) return false;
+ const auto& sig1 = f1.signature();
+ const auto& sig2 = f2.signature();
+ // Functions have positional semantics, so we don't check for names.
+ if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
+ for (int k = 0; k < sig1.input_arg_size(); ++k) {
+ const OpDef::ArgDef& arg1 = sig1.input_arg(k);
+ const OpDef::ArgDef& arg2 = sig2.input_arg(k);
+ if (arg1.type() != arg2.type()) return false;
+ if (arg1.type_attr() != arg2.type_attr()) return false;
+ if (arg1.number_attr() != arg2.number_attr()) return false;
+ if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
+ if (arg1.is_ref() != arg2.is_ref()) return false;
+ }
+ return true;
+}
+
+Status ValidateSignature(const string& interface_name,
+ const std::vector<const FunctionDef*>& equiv_funcs) {
+ if (equiv_funcs.size() < 2) return Status::OK();
+ for (size_t k = 1; k < equiv_funcs.size(); ++k) {
+ if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k]))
+ return errors::InvalidArgument(
+ "Functions '", equiv_funcs[0]->signature().name(), "' and '",
+ equiv_funcs[k]->signature().name(), "' both implement '",
+ interface_name, "' but their signatures do not match.");
+ }
+ return Status::OK();
+}
+
+Status ValidateSignatures(
+ const std::unordered_map<string, std::vector<const FunctionDef*>>&
+ intf_to_func) {
+ for (const auto& item : intf_to_func)
+ TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second));
+ return Status::OK();
+}
+} // namespace
+
+Status FunctionLibraryApiInfo::Init(
+ const FunctionDefLibrary& function_library) {
+ std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func;
+ for (const auto& function : function_library.function()) {
+ std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
+ TF_RETURN_IF_ERROR(func_info->Init(function));
+ // Ignore the function if it does not implement any interface.
+ if (func_info->interface_name().empty()) continue;
+
+ const string& function_name = function.signature().name();
+ const string& interface_name = func_info->interface_name();
+ func_to_intf_[function_name] = interface_name;
+ intf_to_funcs_[interface_name].emplace_back(function_name);
+ intf_to_func[interface_name].emplace_back(&function);
+ func_info_[function_name] = std::move(func_info);
+ }
+ TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func));
+ return Status::OK();
+}
+
+void FunctionLibraryApiInfo::GetEquivalentImplementations(
+ const string& function_name, std::vector<string>* other_names) const {
+ const auto intf_it = func_to_intf_.find(function_name);
+ // The function does not implement any interface.
+ if (intf_it == func_to_intf_.end()) return;
+  CHECK(!intf_it->second.empty()) << "Function " << function_name
+                                  << " should implement at least one interface.";
+ const auto it = intf_to_funcs_.find(intf_it->second);
+ CHECK(it != intf_to_funcs_.end())
+ << "Function " << function_name << " maps to " << intf_it->second
+ << " but no reverse mapping was found";
+ CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty";
+ other_names->reserve(it->second.size() - 1);
+ for (const auto& other_name : it->second) {
+ if (other_name == function_name) continue;
+ other_names->emplace_back(other_name);
+ }
+}
+
+void FunctionLibraryApiInfo::GetBestImplementation(
+ const string& function_name, const string& device,
+ string* best_func_name) const {
+ CHECK(best_func_name != nullptr);
+ const auto func_it = func_to_intf_.find(function_name);
+ if (func_it == func_to_intf_.end()) return;
+
+ const auto it = intf_to_funcs_.find(func_it->second);
+ // No function found for the given interface.
+ if (it == intf_to_funcs_.end()) return;
+ for (const auto& func_name : it->second) {
+ const auto func_api_info = func_info_.find(func_name)->second.get();
+ if (func_api_info->preferred_device() == device) {
+ best_func_name->assign(func_name);
+ return;
+ }
+ }
+  // Didn't find a function with a matching device name; choose the first one
+  // among all the available functions.
+ best_func_name->assign(it->second.front());
+}
+
+const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
+ const string& function_name) const {
+ const auto it = func_info_.find(function_name);
+ if (it == func_info_.end()) return nullptr;
+ return it->second.get();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h
new file mode 100644
index 0000000000..412687c58c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+class FunctionApiInfo {
+ public:
+ FunctionApiInfo();
+ virtual ~FunctionApiInfo();
+
+ Status Init(const FunctionDef& function_def);
+
+ const string& interface_name() const;
+ const string& preferred_device() const;
+
+ private:
+ string interface_name_;
+ string preferred_device_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
+};
+
+// A collection of information about a function and the interface it
+// implements. An interface is a well-defined math operation, e.g.
+// I1 = 2 * x + y. Multiple functions could implement the same interface
+// with different behavior based on different hardware conditions and limits,
+// e.g. F1 = math_ops.add(math_ops.add(x, x), y), or
+// F2 = math_ops.add(math_ops.matmul(x, 2), y).
+class FunctionLibraryApiInfo {
+ public:
+ FunctionLibraryApiInfo();
+ virtual ~FunctionLibraryApiInfo();
+ // Populate the internal field for the functions within the function_library.
+ Status Init(const FunctionDefLibrary& function_library);
+
+ void GetEquivalentImplementations(const string& function_name,
+ std::vector<string>* other_names) const;
+
+ void GetBestImplementation(const string& function_name, const string& device,
+ string* best_func_name) const;
+
+ const FunctionApiInfo* GetApiInfo(const string& function_name) const;
+
+ private:
+ // Map between function name to function details.
+ std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
+ // Map between function name to interface name.
+ std::unordered_map<string, string> func_to_intf_;
+ // Map between interface name to function names.
+ std::unordered_map<string, std::vector<string>> intf_to_funcs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
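A small usage sketch of the lookup API above (not part of the patch): the library is assumed to contain annotated functions along the lines of the test below, and "DoStuffCpu" is one such hypothetical function name. The fallback noted in the comment reflects GetBestImplementation in function_api_info.cc: when no implementation prefers the requested device, the first function registered for the interface is returned.

    #include <string>

    #include "tensorflow/core/framework/function.pb.h"
    #include "tensorflow/core/grappler/optimizers/function_api_info.h"
    #include "tensorflow/core/lib/core/status.h"

    // Sketch: look up the preferred implementation of the interface that the
    // (hypothetical) annotated function "DoStuffCpu" belongs to.
    std::string BestImplFor(const tensorflow::FunctionDefLibrary& library,
                            const std::string& device) {
      tensorflow::grappler::FunctionLibraryApiInfo api_info;
      TF_CHECK_OK(api_info.Init(library));
      std::string best;
      // An unsupported device (e.g. "TPU") falls back to the first function
      // registered for the interface.
      api_info.GetBestImplementation("DoStuffCpu", device, &best);
      return best;
    }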
diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
new file mode 100644
index 0000000000..582890d3e3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+void SetArg(const string& name, const string& type_name,
+ OpDef::ArgDef* arg_def) {
+ arg_def->set_name(name);
+ arg_def->set_type_attr(type_name);
+}
+
+typedef std::pair<string, string> ArgSpec; // name, type.
+
+void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) {
+ for (const auto& arg_spec : args_spec)
+ SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
+ SetArg("output", "float32", sig->add_output_arg());
+}
+
+void PopulateFunction(const string& name, const string& api_interface_name,
+ const string& preferred_device,
+ const std::vector<ArgSpec>& input_args,
+ FunctionDef* func_def) {
+ OpDef* sig = func_def->mutable_signature();
+ sig->set_name(name);
+
+ SetArgs(input_args, sig);
+
+ if (!api_interface_name.empty() || !preferred_device.empty()) {
+ auto* func_attr = func_def->mutable_attr();
+ if (!api_interface_name.empty())
+ (*func_attr)["experimental_api_implements"].set_s(api_interface_name);
+ if (!preferred_device.empty())
+ (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device);
+ }
+}
+
+void PopulateSampleLibrary(const bool mismatch_args,
+ FunctionDefLibrary* func_lib) {
+ const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
+ const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
+ {"in2", "int32"}};
+ PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args,
+ func_lib->add_function());
+ PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
+ mismatch_args ? func_wrong_args : func_args,
+ func_lib->add_function());
+ PopulateFunction("DoThings", "DoThings", "", func_args,
+ func_lib->add_function());
+ PopulateFunction("OneOff", "", "", func_args, func_lib->add_function());
+ PopulateFunction("AnotherOneOff", "", "", func_args,
+ func_lib->add_function());
+}
+
+bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name,
+ const std::vector<string>& expected_other) {
+ std::vector<string> other_impl;
+ lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
+ const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
+ const std::unordered_set<string> expected(expected_other.begin(),
+ expected_other.end());
+ return actual == expected;
+}
+
+bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& function_name, const string& device,
+ const string& expected_function_name) {
+ string best_function_name;
+ lib_api_info.GetBestImplementation(function_name, device,
+ &best_function_name);
+
+ return best_function_name == expected_function_name;
+}
+
+string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->interface_name();
+}
+
+string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->preferred_device();
+}
+
+TEST(FunctionApiInfoTest, ParseTags) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ TF_ASSERT_OK(lib_api_info.Init(func_lib));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
+
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
+
+ EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
+
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu"));
+
+ EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings"));
+  // TPU impl is not available; choose the first available one, which is CPU.
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu"));
+}
+
+TEST(FunctionApiInfoTest, MismatchedArguments) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ true, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ const Status ret = lib_api_info.Init(func_lib);
+ EXPECT_FALSE(ret.ok());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 91794cefe5..c775a26914 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -1071,11 +1071,13 @@ static bool IdentifySwappingCandidates(
// ensure that swapping the tensor back in won't recreate the memory
// bottleneck. Last but not least, we want the tensor to have as few
// remaining uses as possible.
+ //
+ // Note that we must perform the arithmetic inexactly as "double", since
+ // the values do not fit into any integral type.
mem_info.fitness =
- MathUtil::IPow((earliest_use - peak_time).count(), 2);
- mem_info.fitness /= MathUtil::IPow(mem_info.uses_left.size(), 2);
- mem_info.fitness +=
- MathUtil::IPow((allocation_time - peak_time).count(), 2);
+ MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
+ MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
+ MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
mem_info.fitness = -mem_info.fitness;
mem_state.push_back(mem_info);
}
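Spelled out, the fitness value assigned above (all terms computed in double precision, then negated) is:

    \[
      \text{fitness} = -\left(
        \frac{(t_{\text{earliest\_use}} - t_{\text{peak}})^{2}}
             {\lvert \text{uses\_left} \rvert^{2}}
        + (t_{\text{alloc}} - t_{\text{peak}})^{2}
      \right)
    \]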
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 5fd34efeb1..1ed1b22931 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
@@ -72,6 +73,16 @@ bool IsRunOnceOptimizer(const string& name) {
name == "loop_optimizer";
}
+// Check if the graphdef contains nodes that indicate TPU execution.
+bool IsTPUGraphDef(const GraphDef& def) {
+  for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
#define MK_OPT(NAME, VALUE) \
@@ -156,7 +167,7 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return Status::OK();
+ return InitializeCustomGraphOptimizers(optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
@@ -180,9 +191,24 @@ Status MetaOptimizer::InitializeOptimizersByName(
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
+ return InitializeCustomGraphOptimizers(optimizers);
+}
+
+Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
- auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
- optimizer_config.name());
+    // Initialize the ExperimentalImplementationSelector here instead of
+    // through the custom graph optimizer registry, due to a static linking
+    // issue in TensorRT that causes double registration.
+    // TODO(laigd): Remove this hack and switch back to using the registry
+    // once the duplicate static import issue is fixed.
+ std::unique_ptr<CustomGraphOptimizer> custom_optimizer;
+ if (optimizer_config.name() == "ExperimentalImplementationSelector") {
+ custom_optimizer.reset(new ExperimentalImplementationSelector());
+ } else {
+ custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ optimizer_config.name());
+ }
if (custom_optimizer) {
VLOG(2) << "Registered custom configurable graph optimizer: "
<< optimizer_config.name();
@@ -208,7 +234,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
- if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
+ if (cfg_.optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
} else {
TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
@@ -326,10 +352,25 @@ Status MetaOptimizer::RunOptimizer(
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ LOG(INFO) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+ VLOG(1) << "Optimized main graph.";
+
+  // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+  // passes do not handle TPU functions correctly in a variety of ways (note
+  // that, due to the pre-placement TPU graph rewriting passes, the TPU-related
+  // ops are encapsulated away into functions). For example, TPU graphs contain
+  // a TPUReplicateMetadata node that carries relevant TPU metadata, which
+  // Grappler passes could prune away. Grappler passes could also cause issues
+  // around shape inference. Since the desired and existing behavior is to not
+  // optimize TPU functions with Grappler, this check preserves that.
+ if (IsTPUGraphDef(*optimized_graph)) {
+ VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+ return Status::OK();
+ }
// 2. Optimize function library
FunctionLibraryDefinition flib(OpRegistry::Global(),
@@ -393,7 +434,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- VLOG(3) << "Optimized " << optimized_funcs.size()
+ VLOG(1) << "Optimized " << optimized_funcs.size()
<< " functions: " << str_util::Join(optimized_funcs, ", ");
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 151a54cbdf..831c5e37c0 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -52,6 +52,9 @@ class MetaOptimizer : public GraphOptimizer {
// Initialize active optimizers from RewriterConfig optimizer names.
Status InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Initialize active optimizers from RewriterConfig.custom_optimizers.
+ Status InitializeCustomGraphOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 9a03c7dfef..e74e0f7501 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -64,6 +64,13 @@ bool TestOptimizer::optimized_;
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
+class TestGraphOptimizer : public TestOptimizer {
+ public:
+ string name() const override { return "test_graph_optimizer"; }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -83,6 +90,27 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ TestGraphOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizer");
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -98,6 +126,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TF_EXPECT_OK(status);
}
+TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
using test::function::NDef;
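
The new tests exercise the path where a custom graph optimizer is selected through RewriterConfig.custom_optimizers. For reference, a minimal sketch of such an optimizer (class name and behavior here are illustrative, not part of this change); once registered, it can be requested by its registered name (by default, the class name):

    #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
    #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"

    namespace tensorflow {
    namespace grappler {

    // A no-op optimizer that simply copies the input graph.
    class CopyGraphOptimizer : public CustomGraphOptimizer {
     public:
      Status Init(const RewriterConfig_CustomGraphOptimizer* config) override {
        return Status::OK();
      }
      string name() const override { return "copy_graph_optimizer"; }
      Status Optimize(Cluster* cluster, const GrapplerItem& item,
                      GraphDef* optimized_graph) override {
        *optimized_graph = item.graph;
        return Status::OK();
      }
      void Feedback(Cluster* cluster, const GrapplerItem& item,
                    const GraphDef& optimized_graph, double result) override {}
    };

    REGISTER_GRAPH_OPTIMIZER(CopyGraphOptimizer);

    }  // namespace grappler
    }  // namespace tensorflow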
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a2c363ea6e..a428aea7f5 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -304,21 +304,21 @@ Status GrapplerFunctionItemInstantiation::GetArgType(
}
GrapplerFunctionItem::GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, const int graph_def_version,
- bool is_stateful, GraphDef&& function_body)
- : description_(description),
- func_attr_(func_attr),
- input_arg_expansions_(input_arg_expansions),
- output_arg_expansions_(output_arg_expansions),
+ string func_name, string description, AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, const int graph_def_version,
+ const bool is_stateful, GraphDef&& function_body)
+ : description_(std::move(description)),
+ func_attr_(std::move(func_attr)),
+ input_arg_expansions_(std::move(input_arg_expansions)),
+ output_arg_expansions_(std::move(output_arg_expansions)),
is_stateful_(is_stateful) {
- id = func_name;
- keep_ops = keep_nodes;
- // Swap the graph body.
- graph.Swap(&function_body);
+ // Move assign GrapplerItem members.
+ keep_ops = std::move(keep_nodes);
+ id = std::move(func_name);
+ graph = std::move(function_body);
+
graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
@@ -598,8 +598,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
*item = GrapplerFunctionItem(
/*func_name=*/signature.name(), /*description=*/signature.description(),
/*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
- inputs, outputs, keep_nodes, graph_def_version, is_stateful,
- std::move(function_body));
+ std::move(inputs), std::move(outputs), std::move(keep_nodes),
+ graph_def_version, is_stateful, std::move(function_body));
return Status::OK();
}
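
The constructor change above switches from const-reference parameters to by-value parameters that are moved into the members (the "sink argument" idiom), so callers that pass rvalues, as MakeGrapplerFunctionItem now does, avoid an extra copy. A minimal standalone sketch of the idiom (types here are illustrative):

    #include <string>
    #include <utility>
    #include <vector>

    // Sink arguments: take by value, then move into members. Callers passing
    // rvalues pay one move; callers passing lvalues pay one copy plus one move.
    class Item {
     public:
      Item(std::string name, std::vector<int> values)
          : name_(std::move(name)), values_(std::move(values)) {}

     private:
      std::string name_;
      std::vector<int> values_;
    };

    int main() {
      std::vector<int> v = {1, 2, 3};
      Item copied("kept", v);            // v is copied, then moved internally.
      Item moved("sunk", std::move(v));  // v's storage is moved, no copy.
      (void)copied;
      (void)moved;
      return 0;
    }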
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 61588ceb83..733caf325f 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation {
class GrapplerFunctionItem : public GrapplerItem {
public:
GrapplerFunctionItem() = default;
- GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, const int versions,
- bool is_stateful, GraphDef&& function_body);
+ GrapplerFunctionItem(string func_name, string description,
+ AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, int graph_def_version,
+ bool is_stateful, GraphDef&& function_body);
const string& description() const;
diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h
index 4fb7aab647..ceb9f5dbf2 100644
--- a/tensorflow/core/grappler/utils/scc.h
+++ b/tensorflow/core/grappler/utils/scc.h
@@ -24,15 +24,16 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Compute modified strongly connected components:
+// Computes modified strongly connected components:
// All nodes that are not part of a loop are assigned the special -1 id
// All nodes that are part of at least one loop are assigned a positive
// component id: if 2 nodes v and w are reachable from one another (i.e. if they
// belong to the same scc), they'll be assigned the same id, otherwise they'll
-// be assigned distinct ids. Returns the number of distinct ids.
+// be assigned distinct ids. *num_components is set to the number of distinct
+// ids.
void StronglyConnectedComponents(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
- int* num_ids);
+ int* num_components);
// Returns the number of individual loops present in the graph, and populates the
// 'loops' argument with the collection of loops (denoted by their loop ids) a
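
The renamed out-parameter makes the contract clearer: the function fills in per-node component ids and reports how many distinct ids were assigned. A small usage sketch against the declaration above (the helper function and variable names are illustrative):

    #include <unordered_map>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/utils/scc.h"

    namespace {

    // Counts how many nodes of `graph` participate in at least one loop.
    int CountLoopNodes(const tensorflow::GraphDef& graph) {
      std::unordered_map<const tensorflow::NodeDef*, int> components;
      int num_components = 0;
      tensorflow::grappler::StronglyConnectedComponents(graph, &components,
                                                        &num_components);
      int loop_nodes = 0;
      for (const auto& node_and_id : components) {
        // Nodes outside any loop are assigned the special id -1.
        if (node_and_id.second != -1) ++loop_nodes;
      }
      return loop_nodes;
    }

    }  // namespace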
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 633fe9ab77..7aa1169061 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,6 +30,7 @@ load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_test",
+ "tf_cc_test_mkl",
"tf_cc_tests",
"tf_cc_binary",
"tf_copts",
@@ -50,6 +51,10 @@ load(
"tf_kernel_tests_linkstatic",
)
load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"if_mkl_ml",
@@ -643,14 +648,7 @@ cc_library(
":split_v_op",
":strided_slice_op",
":tile_ops",
- ] + if_mkl(
- [
- ":mkl_transpose_op",
- ],
- [
- ":transpose_op",
- ],
- ) + [
+ ":transpose_op",
":unique_op",
":unpack_op",
":unravel_index_op",
@@ -893,24 +891,13 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)
-if_mkl(
- [tf_mkl_kernel_library(
- name = "mkl_transpose_op",
- srcs = [
- "mkl_transpose_op.cc",
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS + mkl_deps(),
- )],
- [tf_kernel_library(
- name = "transpose_op",
- srcs = [
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS,
- )],
+tf_kernel_library(
+ name = "transpose_op",
+ srcs = [
+ "transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]),
)
tf_kernel_library(
@@ -1123,7 +1110,7 @@ tf_cuda_cc_test(
name = "depthwise_conv_ops_test",
size = "small",
srcs = ["depthwise_conv_ops_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":conv_ops",
":image",
@@ -2296,6 +2283,31 @@ tf_cc_tests(
],
)
+cc_library(
+ name = "eigen_benchmark",
+ testonly = 1,
+ hdrs = [
+ "eigen_benchmark.h",
+ ":eigen_helpers",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
+ name = "eigen_benchmark_cpu_test",
+ srcs = ["eigen_benchmark_cpu_test.cc"],
+ deps = [
+ ":eigen_benchmark",
+ ":eigen_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//third_party/eigen3",
+ ],
+)
+
tf_cc_tests(
name = "basic_ops_benchmark_test",
size = "small",
@@ -4196,6 +4208,7 @@ cc_library(
"hinge-loss.h",
"logistic-loss.h",
"loss.h",
+ "poisson-loss.h",
"smooth-hinge-loss.h",
"squared-loss.h",
],
@@ -4496,6 +4509,25 @@ tf_kernel_library(
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "substr_op_test",
+ size = "small",
+ srcs = ["substr_op_test.cc"],
+ deps = [
+ ":substr_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "as_string_op",
prefix = "as_string_op",
@@ -5176,6 +5208,7 @@ filegroup(
"fifo_queue.cc",
"fifo_queue_op.cc",
"fused_batch_norm_op.cc",
+ "listdiff_op.cc",
"population_count_op.cc",
"population_count_op.h",
"winograd_transform.h",
@@ -6200,6 +6233,26 @@ tf_mkl_kernel_library(
] + mkl_deps(),
)
+tf_cc_test_mkl(
+ name = "mkl_conv_ops_test",
+ size = "small",
+ srcs = ["mkl_conv_ops_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_mkl_kernel_library(
name = "mkl_tfconv_op",
prefix = "mkl_tfconv",
@@ -6325,6 +6378,15 @@ tf_mkl_kernel_library(
deps = NN_DEPS + mkl_deps() + [":cwise_op"],
)
+tf_mkl_kernel_library(
+ name = "mkl_transpose_op",
+ srcs = [
+ "mkl_transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
# NOTE(lespeholt): This rule is deprecated, please use:
# tensorflow/core/util/batch_util.h
cc_library(
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 7b28c8e91f..e15ea82e7d 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -134,8 +134,8 @@ class BiasOp : public BinaryOp<T> {
if (data_format_ == FORMAT_NCHW) {
int32 batch, height, width, channel;
GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
- Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
- Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1);
+ Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width);
const Device& d = context->eigen_device<Device>();
output->tensor<T, 4>().device(d) =
input.tensor<T, 4>() +
@@ -247,14 +247,14 @@ class BiasGradOp : public OpKernel {
OP_REQUIRES(context, output_backprop.dims() == 4,
errors::InvalidArgument(
"NCHW format supports only 4D input/output tensor."));
- Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
#ifdef EIGEN_HAS_INDEX_LIST
using idx0 = Eigen::type2index<0>;
using idx2 = Eigen::type2index<2>;
using idx3 = Eigen::type2index<3>;
Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
#else
- Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+ Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
@@ -263,11 +263,12 @@ class BiasGradOp : public OpKernel {
.sum(reduction_axes)
.template cast<T>(); // End of code by intel_tf.
} else {
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+ Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+ channel);
#ifdef EIGEN_HAS_INDEX_LIST
Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<Eigen::Index, 1> reduction_axis = {0};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
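
The changes above widen the Eigen dimension and reduction-axis index types from 32-bit ints to Eigen::Index, presumably to avoid 32-bit overflow for very large tensors. A small self-contained Eigen sketch of the same broadcast pattern (assumes only the Eigen Tensor module; names are illustrative):

    #include <unsupported/Eigen/CXX11/Tensor>

    // Adds a per-channel bias to an NCHW tensor by reshaping the bias to
    // [1, C, 1, 1] and broadcasting it to [N, 1, H, W], mirroring BiasOp above
    // but with Eigen::Index-typed dimensions throughout.
    Eigen::Tensor<float, 4> AddBiasNCHW(const Eigen::Tensor<float, 4>& input,
                                        const Eigen::Tensor<float, 1>& bias) {
      const Eigen::Index batch = input.dimension(0);
      const Eigen::Index channel = input.dimension(1);
      const Eigen::Index height = input.dimension(2);
      const Eigen::Index width = input.dimension(3);
      Eigen::DSizes<Eigen::Index, 4> bias_dims(1, channel, 1, 1);
      Eigen::DSizes<Eigen::Index, 4> broadcast_dims(batch, 1, height, width);
      return input + bias.reshape(bias_dims).broadcast(broadcast_dims);
    }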
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021c63..4e8bfa02fc 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@ load(
tf_proto_library(
name = "boosted_trees_proto",
- srcs = ["boosted_trees.proto"],
+ srcs = [
+ "boosted_trees.proto",
+ ],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
@@ -87,9 +89,21 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "quantile_ops",
+ srcs = ["quantile_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+ ],
+)
+
+tf_kernel_library(
name = "boosted_trees_ops",
deps = [
":prediction_ops",
+ ":quantile_ops",
":resource_ops",
":stats_ops",
":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000000..d1840941c1
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+// Generates deduplicated bucket boundaries from a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+ const int64 num_boundaries) {
+ std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+ // Uniquify elements as we may get dupes.
+ auto end_it = std::unique(boundaries.begin(), boundaries.end());
+ boundaries.resize(std::distance(boundaries.begin(), end_it));
+ return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ const int64 num_quantiles) {
+ // Do not de-dup boundaries. Exactly num_quantiles boundary values
+ // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+ CHECK_EQ(boundaries.size(), num_quantiles);
+ return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+ const OpInputList& buckets_list) {
+ const auto& buckets = buckets_list[feature].flat<float>();
+ std::vector<float> buckets_vector(buckets.data(),
+ buckets.data() + buckets.size());
+ return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateQuantileStreamResourceOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Only create the resource if one does not already exist; report the status
+ // for all other errors. If one already exists, the new one is unreffed.
+ // An epsilon value of zero could cause performance issues and is therefore
+ // disallowed.
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+ OP_REQUIRES(
+ context, epsilon > 0,
+ errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+ const Tensor* num_streams_t;
+ OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+ int64 num_streams = num_streams_t->scalar<int64>()();
+
+ auto result =
+ new QuantileStreamResource(epsilon, max_elements_, num_streams);
+ auto status = CreateResource(context, HandleFromInput(context, 0), result);
+ if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES(context, false, status);
+ }
+ }
+
+ private:
+ // An upper bound on the number of entries that the summaries might have
+ // for a feature.
+ int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+ BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeQuantileSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read the float features list.
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+ // Parse example weights and get batch size.
+ const Tensor* example_weights_t;
+ OP_REQUIRES_OK(context,
+ context->input(kExampleWeightsName, &example_weights_t));
+ auto example_weights = example_weights_t->flat<float>();
+ const int64 batch_size = example_weights.size();
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+
+ OpOutputList summaries_output_list;
+ OP_REQUIRES_OK(
+ context, context->output_list(kSummariesName, &summaries_output_list));
+
+ auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+ // Iterating features.
+ for (int64 index = begin; index < end; index++) {
+ const auto feature_values = float_features_list[index].flat<float>();
+ QuantileStream stream(epsilon, batch_size + 1);
+ // Run quantile summary generation.
+ for (int64 j = 0; j < batch_size; j++) {
+ stream.PushEntry(feature_values(j), example_weights(j));
+ }
+ stream.Finalize();
+ const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+ Tensor* output_t;
+ OP_REQUIRES_OK(
+ context,
+ summaries_output_list.allocate(
+ index,
+ TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+ &output_t));
+ auto output = output_t->matrix<float>();
+ for (auto row = 0; row < summary_entry_list.size(); row++) {
+ const auto& entry = summary_entry_list[row];
+ output(row, 0) = entry.value;
+ output(row, 1) = entry.weight;
+ output(row, 2) = entry.min_rank;
+ output(row, 3) = entry.max_rank;
+ }
+ }
+ };
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_summary_gen);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+ BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ OpInputList summaries_list;
+ OP_REQUIRES_OK(context,
+ context->input_list(kSummariesName, &summaries_list));
+ int32 num_streams = stream_resource->num_streams();
+ CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+ auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+ // Iterating all features.
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+ const Tensor& summaries = summaries_list[feature_idx];
+ const auto summary_values = summaries.matrix<float>();
+ const auto& tensor_shape = summaries.shape();
+ const int64 entries_size = tensor_shape.dim_size(0);
+ CHECK_EQ(tensor_shape.dim_size(1), 4);
+ std::vector<QuantileSummaryEntry> summary_entries;
+ summary_entries.reserve(entries_size);
+ for (int64 i = 0; i < entries_size; i++) {
+ float value = summary_values(i, 0);
+ float weight = summary_values(i, 1);
+ float min_rank = summary_values(i, 2);
+ float max_rank = summary_values(i, 3);
+ QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+ summary_entries.push_back(entry);
+ }
+ stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_add_summary);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceFlushOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const Tensor* num_buckets_t;
+ OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+ const int64 num_buckets = num_buckets_t->scalar<int64>()();
+ const int64 num_streams = stream_resource->num_streams();
+
+ auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+ QuantileStream* stream = stream_resource->stream(stream_idx);
+ stream->Finalize();
+ stream_resource->set_boundaries(
+ generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+ : GenerateBoundaries(*stream, num_buckets),
+ stream_idx);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_flush);
+
+ stream_resource->set_buckets_ready(true);
+ }
+
+ private:
+ bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+ : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const int64 num_streams = stream_resource->num_streams();
+ CHECK_EQ(num_features_, num_streams);
+ OpOutputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+
+ auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+ const auto& boundaries = stream_resource->boundaries(stream_idx);
+ Tensor* bucket_boundaries_t = nullptr;
+ OP_REQUIRES_OK(context,
+ bucket_boundaries_list.allocate(
+ stream_idx, {static_cast<int64>(boundaries.size())},
+ &bucket_boundaries_t));
+ auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+ memcpy(quantiles_flat, boundaries.data(),
+ sizeof(float) * boundaries.size());
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_get_buckets);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantile thresholds and the input data, this operation
+// converts the input features into buckets (categorical values), depending on
+// which quantile they fall into.
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+ explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read the float features list.
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+ OpInputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+ OP_REQUIRES(context,
+ tensorflow::TensorShapeUtils::IsVector(
+ bucket_boundaries_list[0].shape()),
+ errors::InvalidArgument(
+ strings::Printf("Buckets should be flat vectors.")));
+ OpOutputList buckets_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+ auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+ // Iterating over all resources
+ for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+ const Tensor& values_tensor = float_features_list[feature_idx];
+ const int64 num_values = values_tensor.dim_size(0);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(
+ context, buckets_list.allocate(
+ feature_idx, TensorShape({num_values, 1}), &output_t));
+ auto output = output_t->matrix<int32>();
+
+ const std::vector<float>& bucket_boundaries_vector =
+ GetBuckets(feature_idx, bucket_boundaries_list);
+ CHECK(!bucket_boundaries_vector.empty())
+ << "Got empty buckets for feature " << feature_idx;
+ auto flat_values = values_tensor.flat<float>();
+ for (int64 instance = 0; instance < num_values; instance++) {
+ const float value = flat_values(instance);
+ auto bucket_iter =
+ std::lower_bound(bucket_boundaries_vector.begin(),
+ bucket_boundaries_vector.end(), value);
+ if (bucket_iter == bucket_boundaries_vector.end()) {
+ --bucket_iter;
+ }
+ const int32 bucket = static_cast<int32>(
+ bucket_iter - bucket_boundaries_vector.begin());
+ // Bucket id.
+ output(instance, 0) = bucket;
+ }
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_features_;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_get_quantiles);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+ BoostedTreesBucketizeOp);
+
+} // namespace tensorflow
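
The kernels above are thin wrappers around the weighted quantile stream: summaries are built per feature, merged into a resource, and flushed into bucket boundaries. A hedged sketch of the underlying stream API as it is used here (only the calls that appear above are assumed; the helper name is illustrative):

    #include <vector>

    #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"

    using QuantileStream =
        tensorflow::boosted_trees::quantiles::WeightedQuantilesStream<float, float>;

    // Feeds weighted values into a stream and extracts approximate bucket
    // boundaries, the same sequence of calls the kernels above perform.
    std::vector<float> SketchBoundaries(const std::vector<float>& values,
                                        const std::vector<float>& weights,
                                        int64_t num_buckets) {
      // epsilon bounds the approximation error; max_elements sizes the buffers.
      QuantileStream stream(/*eps=*/0.01, /*max_elements=*/values.size() + 1);
      for (size_t i = 0; i < values.size(); ++i) {
        stream.PushEntry(values[i], weights[i]);
      }
      stream.Finalize();
      // The returned boundaries may contain duplicates for skewed data, which is
      // why the GenerateBoundaries() helper above de-duplicates them.
      return stream.GenerateBoundaries(num_buckets);
    }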
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
index 3163c63949..12d9473776 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -1,5 +1,5 @@
# Description:
-# This directory contains common utilities used in boosted_trees.
+# This directory contains common quantile utilities used in boosted_trees.
package(
default_visibility = ["//tensorflow:internal"],
)
@@ -16,6 +16,7 @@ cc_library(
name = "weighted_quantiles",
srcs = [],
hdrs = [
+ "quantile_stream_resource.h",
"weighted_quantiles_buffer.h",
"weighted_quantiles_stream.h",
"weighted_quantiles_summary.h",
@@ -23,6 +24,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000000..1c31724272
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+ BoostedTreesQuantileStreamResource(const float epsilon,
+ const int64 max_elements,
+ const int64 num_streams)
+ : are_buckets_ready_(false),
+ epsilon_(epsilon),
+ num_streams_(num_streams),
+ max_elements_(max_elements) {
+ streams_.reserve(num_streams_);
+ boundaries_.reserve(num_streams_);
+ for (int64 idx = 0; idx < num_streams; ++idx) {
+ streams_.push_back(QuantileStream(epsilon, max_elements));
+ boundaries_.push_back(std::vector<float>());
+ }
+ }
+
+ string DebugString() override { return "QuantileStreamResource"; }
+
+ tensorflow::mutex* mutex() { return &mu_; }
+
+ QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+ const std::vector<float>& boundaries(const int64 index) {
+ return boundaries_[index];
+ }
+
+ void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+ boundaries_[index] = boundaries;
+ }
+
+ float epsilon() const { return epsilon_; }
+ int64 num_streams() const { return num_streams_; }
+
+ bool are_buckets_ready() const { return are_buckets_ready_; }
+ void set_buckets_ready(const bool are_buckets_ready) {
+ are_buckets_ready_ = are_buckets_ready;
+ }
+
+ private:
+ ~BoostedTreesQuantileStreamResource() override {}
+
+ // Mutex for the whole resource.
+ tensorflow::mutex mu_;
+
+ // Quantile streams.
+ std::vector<QuantileStream> streams_;
+
+ // Stores the boundaries. Same size as streams_.
+ std::vector<std::vector<float>> boundaries_;
+
+ // Whether boundaries have been created. Boundaries are initially empty until
+ // set_boundaries() is called.
+ bool are_buckets_ready_;
+
+ const float epsilon_;
+ const int64 num_streams_;
+ // An upper bound on the number of elements.
+ int64 max_elements_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
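
The resource simply bundles one stream and one boundary vector per feature behind a single mutex. A hedged lifecycle sketch using only the methods declared above (the resource manager plumbing and error handling from the kernels are omitted):

    #include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
    #include "tensorflow/core/platform/mutex.h"

    void SketchResourceLifecycle() {
      // One stream per feature; ResourceBase is reference counted.
      auto* resource = new tensorflow::BoostedTreesQuantileStreamResource(
          /*epsilon=*/0.01f, /*max_elements=*/1 << 16, /*num_streams=*/1);
      {
        // The kernels above take this mutex before touching any stream.
        tensorflow::mutex_lock l(*resource->mutex());
        tensorflow::QuantileStream* stream = resource->stream(0);
        stream->PushEntry(/*value=*/1.5f, /*weight=*/1.0f);
        stream->Finalize();
        resource->set_boundaries(
            stream->GenerateBoundaries(/*num_boundaries=*/10), 0);
        resource->set_buckets_ready(true);
      }
      resource->Unref();
    }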
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index a7836896c7..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator.
+ // reduction_type: The reduction type, i.e., MEAN or SUM.
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
- const string& name)
- : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+ const string& name, const string& reduction_type)
+ : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+ reduction_type) {}
~ConditionalAccumulator() override{};
protected:
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c56b8..292cf0cd64 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
ConditionalAccumulatorBase::ConditionalAccumulatorBase(
- const DataType& dtype, const PartialTensorShape& shape, const string& name)
- : dtype_(dtype), shape_(shape), name_(name) {
+ const DataType& dtype, const PartialTensorShape& shape, const string& name,
+ const string& reduction_type)
+ : dtype_(dtype),
+ shape_(shape),
+ name_(name),
+ reduction_type_(reduction_type) {
counter_ = 0;
current_global_step_ = 0;
}
@@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
current_global_step_++;
// Average the accumulated gradient
- DivideAccumGradByCounter(ctx);
+ if (reduction_type_ == "MEAN") {
+ DivideAccumGradByCounter(ctx);
+ }
// Set output for accumulated gradient tensor
bool successful_set_output = SetOutput(ctx);
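
The new reduction_type string selects whether TakeGrad divides the accumulated gradient by the number of applied gradients. Schematically (a minimal sketch with illustrative names, not the accumulator API itself):

    #include <numeric>
    #include <string>
    #include <vector>

    // With reduction_type == "MEAN" the accumulator returns the average of the
    // applied gradients; with "SUM" it returns their plain sum.
    float Reduce(const std::vector<float>& grads,
                 const std::string& reduction_type) {
      const float sum = std::accumulate(grads.begin(), grads.end(), 0.0f);
      return reduction_type == "MEAN" ? sum / static_cast<float>(grads.size())
                                      : sum;
    }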
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index b7b7482a00..4a5ec6f0fb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name);
+ const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback;
@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_;
const PartialTensorShape shape_;
const string name_;
+ const string reduction_type_;
mutex mu_;
int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 012a0dcc12..ca24d690f8 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
&accumulator_handle_, nullptr));
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("reduction_type", &reduction_type_));
}
void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
DataType dtype_;
PartialTensorShape shape_;
ContainerInfo cinfo_;
+ string reduction_type_;
private:
Status SetAccumulatorHandle(OpKernelContext* ctx)
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index e13bf8a4c6..52ac51a9b6 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
ConditionalAccumulator<Device, T>* accumulator =
- new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+ new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+ reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index de9b69828e..639c3062cc 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -137,17 +137,16 @@ struct MatMulConvFunctor {
}
};
-// Shuffles a filter tensor from:
-// [<spatial_dims>, in, out]
-// to:
-// [out, in, <spatial_dims>]
+// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
+//
+// Note: Currently OIHW is the only supported destination format. Support for
+// OHWI format will be added in a follow-up change.
template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
typename TTypes<T, NDIMS, IndexType>::Tensor out) {
- // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
- // to speed up the shuffle operation.
+ // Merge the spatial dimensions together to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
merged_dims[0] = in.dimension(0); // spatial dimensions
for (int i = 1; i < NDIMS - 2; ++i) {
@@ -156,16 +155,30 @@ struct TransformFilter {
merged_dims[1] = in.dimension(NDIMS - 2); // input filters
merged_dims[2] = in.dimension(NDIMS - 1); // output filters
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported destination filter format: "
+ << ToString(dst_filter_format);
+ // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
+ // in the beginning.
+ Eigen::DSizes<IndexType, 3> shuffling_perm =
+ Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
- expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
- expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
- for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
- expanded_dims[i + 2] = in.dimension(i);
+ int out_index = 0;
+ for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
+ if (shuffling_perm[merged_dim] == 0) {
+ for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
+ expanded_dims[out_index++] = in.dimension(spatial_dim);
+ }
+ } else {
+ constexpr int kLastSpatialDim = NDIMS - 3;
+ expanded_dims[out_index++] =
+ in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+ }
}
- out.device(d) = in.reshape(merged_dims)
- .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
- .reshape(expanded_dims);
+ out.device(d) =
+ in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
}
};
@@ -282,7 +295,9 @@ struct SwapDimension0And2InTensor3 {
const gtl::ArraySlice<int64>& input_dims, T* out);
};
-// Reverses the effect of TransformFilter above.
+// Transforms the filter back from OIHW to HWIO format, reversing the effect of
+// TransformFilter above.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename Device, typename T, int NDIMS>
struct ReverseTransformFilter {
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
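
TransformFilter now takes an explicit destination filter format, with OIHW as the only supported target for now. The merged-dimension shuffle above is equivalent to the direct 4-D permutation below (plain Eigen; the helper name is illustrative):

    #include <unsupported/Eigen/CXX11/Tensor>

    // Permutes a 4-D filter from HWIO (height, width, in, out) to OIHW. Entry i
    // of the permutation names the input dimension that becomes output
    // dimension i.
    Eigen::Tensor<float, 4> HWIOToOIHW(const Eigen::Tensor<float, 4>& hwio) {
      const Eigen::array<int, 4> perm = {3, 2, 0, 1};  // O<-3, I<-2, H<-0, W<-1
      return hwio.shuffle(perm);
    }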
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 02e3655ad1..b819c6f910 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
namespace tensorflow {
@@ -28,6 +29,14 @@ namespace functor {
template <typename Device, typename T>
struct CuboidConvolution;
+// Backward input pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardInput;
+
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
@@ -42,6 +51,40 @@ struct CuboidConvolution<CPUDevice, T> {
}
};
+template <typename T>
+struct CuboidConvolutionBackwardInput<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor input_backward,
+ typename TTypes<T, 5>::ConstTensor filter,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward,
+ input_backward.dimension(3), // input_planes
+ input_backward.dimension(2), // input_rows
+ input_backward.dimension(1), // input_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor filter_backward,
+ typename TTypes<T, 5>::ConstTensor input,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward,
+ filter_backward.dimension(2), // kernel_planes
+ filter_backward.dimension(1), // kernel_rows
+ filter_backward.dimension(0), // kernel_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 63b1bcda43..9e86a16b66 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -1018,7 +1018,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d664a11e73..43bb5ea56c 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -901,7 +901,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
&transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_out_backprop;
@@ -1090,7 +1091,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index fc0a2f123f..507720c998 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -41,6 +41,17 @@ limitations under the License.
namespace tensorflow {
+// Compute padding for the given spatial dimension.
+int ConvBackpropDimensions::SpatialPadding(const Padding& padding,
+ int dim) const {
+ return (padding == VALID)
+ ? 0
+ : std::max<int>(
+ 0, static_cast<int>((output_size(dim) - 1) * stride(dim) +
+ (filter_size(dim) - 1) * dilation(dim) +
+ 1 - input_size(dim)));
+}
+
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
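
SpatialPadding returns the total padding for one spatial dimension: zero for VALID, and for SAME the amount needed so that output_size strided applications of the dilated filter cover the input. Written out for a single dimension (a sketch with plain ints; the function name is illustrative, not the class API):

    #include <algorithm>

    // Total padding needed along one spatial dimension so that `output_size`
    // strided applications of a dilated filter cover an input of `input_size`.
    // This mirrors ConvBackpropDimensions::SpatialPadding above.
    int TotalSamePadding(int input_size, int filter_size, int output_size,
                         int stride, int dilation) {
      const int needed = (output_size - 1) * stride +
                         (filter_size - 1) * dilation + 1 - input_size;
      return std::max(0, needed);
    }

For example, with input_size=7, filter_size=3, stride=2, dilation=1 and SAME output_size=4, this yields (4-1)*2 + (3-1)*1 + 1 - 7 = 2, i.e. one element of padding on each side.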
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 535586d53a..9551959463 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -234,6 +234,16 @@ struct ConvBackpropDimensions {
// Input and output feature depth.
int64 in_depth, out_depth;
+
+ // Convenience access methods for spatial dimensions properties.
+ int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
+ int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
+ int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
+ int64 stride(int dim) const { return spatial_dims[dim].stride; }
+ int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
+
+ // Compute padding for the given spatial dimension.
+ int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9aba..bab91f5e86 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) the image in storage
+// order (planes, height, width, depth), reconstructed from patches in 'col_data',
+// which is required to be in storage order (out_planes * out_height *
+// out_width, filter_planes, filter_height, filter_width, in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* im_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data =
+ im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+ // Skip over the rest of the current image row.
+ im_patch_data += depth * (width - filter_w);
+ }
+ // Skip over the rest of the current image plane.
+ im_patch_data += (depth * width) * (height - filter_h);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+// Returns in 'col_data' the image patches in storage order (planes, height,
+// width, depth), extracted from the image at 'input_data', which is required to be in
+// storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* col_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ memcpy(col_data,
+ input_data +
+ (ip * height * width + ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+ // This should be simply padded with zero.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+} // namespace
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
- const Tensor& out_backprop = context->input(2); \
- OP_REQUIRES( \
- context, input_shape.dims() == 5, \
- errors::InvalidArgument(label, ": input must be 5-dimensional")); \
- OP_REQUIRES( \
- context, filter_shape.dims() == 5, \
- errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
- OP_REQUIRES( \
- context, out_backprop.dims() == 5, \
- errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
- const int64 batch = input_shape.dim_size(0); \
- OP_REQUIRES( \
- context, batch == out_backprop.dim_size(0), \
- errors::InvalidArgument( \
- label, ": input and out_backprop must have the same batch size")); \
- const std::array<int64, 3> input_size = { \
- {GetTensorDim(input_shape, data_format_, '0'), \
- GetTensorDim(input_shape, data_format_, '1'), \
- GetTensorDim(input_shape, data_format_, '2')}}; \
- const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
- const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \
- filter_shape.dim_size(1), \
- filter_shape.dim_size(2)}}; \
- const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \
- const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \
- const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \
- OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \
- errors::InvalidArgument( \
- label, ": input and filter must have the same depth")); \
- const int64 out_depth = filter_shape.dim_size(4); \
- OP_REQUIRES( \
- context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \
- errors::InvalidArgument( \
- label, ": filter and out_backprop must have the same out_depth")); \
- const std::array<int64, 3> dilations = { \
- {GetTensorDim(dilation_, data_format_, '0'), \
- GetTensorDim(dilation_, data_format_, '1'), \
- GetTensorDim(dilation_, data_format_, '2')}}; \
- const std::array<int64, 3> strides = { \
- {GetTensorDim(stride_, data_format_, '0'), \
- GetTensorDim(stride_, data_format_, '1'), \
- GetTensorDim(stride_, data_format_, '2')}}; \
- std::array<int64, 3> out, padding; \
- OP_REQUIRES_OK( \
- context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \
- padding_, &out, &padding)); \
- OP_REQUIRES(context, output_planes == out[0], \
- errors::InvalidArgument( \
- label, \
- ": Number of planes of out_backprop doesn't match " \
- "computed: actual = ", \
- output_planes, ", computed = ", out[0])); \
- OP_REQUIRES( \
- context, output_rows == out[1], \
- errors::InvalidArgument( \
- label, ": Number of rows of out_backprop doesn't match computed: ", \
- "actual = ", output_rows, ", computed = ", out[1])); \
- OP_REQUIRES( \
- context, output_cols == out[2], \
- errors::InvalidArgument( \
- label, ": Number of cols of out_backprop doesn't match computed: ", \
- "actual = ", output_cols, ", computed = ", out[2])); \
- const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
- const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
- const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
- const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
- const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
- const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
- const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
- const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
- const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
- const auto bottom_pad_planes = \
- padded_out_planes - expanded_out_planes - top_pad_planes; \
- const auto bottom_pad_rows = \
- padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
- padded_out_cols - expanded_out_cols - left_pad_cols; \
- VLOG(2) << "Conv3d: " << label \
- << ": expanded_out_planes = " << expanded_out_planes \
- << ": expanded_out_rows = " << expanded_out_rows \
- << ", expanded_out_cols = " << expanded_out_cols \
- << ", padded_out_planes = " << padded_out_planes \
- << ", padded_out_rows = " << padded_out_rows \
- << ", padded_out_cols = " << padded_out_cols \
- << ", top_pad_planes = " << top_pad_planes \
- << ", top_pad_rows = " << top_pad_rows \
- << ", left_pad_cols = " << left_pad_cols \
- << ", bottom_pad_planes = " << bottom_pad_planes \
- << ", bottom_pad_rows = " << bottom_pad_rows \
- << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
public:
@@ -192,6 +212,116 @@ class Conv3DBackpropInputOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape input_shape;
+ if (takes_shape_) {
+ const Tensor& input_sizes = context->input(0);
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+ } else {
+ input_shape = context->input(0).shape();
+ }
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
+ Tensor* in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
+};
+
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+  // limit, fall back on the Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& filter = context->input(1);
+ const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -200,51 +330,239 @@ class Conv3DBackpropInputOp : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- // Fill out a padded out_backprop.
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
- padded_out_cols, out_depth});
- Tensor padded_output;
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // Use L3 cache size as target working set size.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ // Calculate size of matrices involved in MatMul: C = A x B.
+ const int64 size_A = output_image_size * dims.out_depth;
+
+ const int64 size_B = filter_total_size * dims.out_depth;
+
+ const int64 size_C = output_image_size * filter_total_size;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ // Use parallel tensor contractions if there is no batching.
+ //
+    // Compared to the Conv2D code, this version is missing the work size
+    // estimation. In benchmarks we did not find a case where running a
+    // parallel contraction was beneficial compared to sharding and matmuls.
+ const bool use_parallel_contraction = dims.batch_size == 1;
+
+ const size_t shard_size =
+ use_parallel_contraction
+ ? 1
+ : (target_working_set_size + work_unit_size - 1) / work_unit_size;
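+    // Rough illustration of the sharding heuristic (hypothetical numbers):
+    // with an 8 MiB L3 cache and T = float, target_working_set_size is
+    // 8 MiB / 4 B = 2,097,152 elements; if work_unit_size is 600,000 elements,
+    // shard_size = ceil(2,097,152 / 600,000) = 4, i.e. four images are
+    // processed per work unit in the sharded branch below.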
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+    // If the temporary allocation overhead is too large, fall back on the
+    // Eigen implementation, which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
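+    // Hypothetical example: if input, filter and out_backprop together hold
+    // 1,000,000 elements and the col_buffer would need 30,000,000 elements,
+    // col_buffer_overhead = 30 exceeds kMaxTempAllocationOverhead (25), so the
+    // kernel takes the Eigen fallback above instead of allocating a temporary
+    // buffer ~30x larger than its inputs.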
+
+ Tensor col_buffer;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // Fill a new "reverted" filter. We need to transpose the in_depth and
- // out_depth for the filter and reverse the planes, rows and cols.
- TensorShape r_filter_shape(
- {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
- Tensor r_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
- r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 5>());
- const Tensor& r_filter_cref = r_filter;
-
- // Now we can call conv_3d directly.
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* filter_data = filter.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+ auto in_backprop_flat = in_backprop->template flat<T>();
+ T* input_backprop_data = in_backprop_flat.data();
+ in_backprop_flat.device(context->eigen_device<Device>()) =
+ in_backprop_flat.constant(T(0));
+
+ if (use_parallel_contraction) {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ // Initialize contraction dims (we need to transpose 'B' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 1;
+ contract_dims[0].second = 1;
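+      // Contracting index 1 of A [output_image_size, out_depth] with index 1
+      // of B [filter_total_size, out_depth] computes A * B^T, i.e. the same
+      // [output_image_size, filter_total_size] product as the explicit
+      // MatrixMap path in the sharded branch below.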
+
+ for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+ // Compute gradient into col_buffer.
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstTensorMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, dims.out_depth);
+ ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+ Col2im<T>(col_buffer_data, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ } else {
+ typedef Eigen::Map<
+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>
+ ConstMatrixMap;
+
+ for (int image_id = 0; image_id < dims.batch_size;
+ image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+ &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+ &output_image_size, &filter_total_size,
+ &input_backprop_data, &col_buffer_data,
+ &out_backprop_data, &filter_data, &input_offset,
+ &output_offset, &size_C](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ T* im2col_buf = col_buffer_data + shard_id * size_C;
+ T* input_data = input_backprop_data + shard_id * input_offset;
+ const T* out_data = out_backprop_data + shard_id * output_offset;
+
+ // Compute gradient into 'im2col_buf'.
+ MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(im2col_buf, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_data);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ work_unit_size, shard);
+
+ input_backprop_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
+ }
}
private:
@@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};
+// The custom backprop input kernel is 30%-4x faster than the default Eigen
+// implementation when compiled with AVX2 (at the cost of ~2x-8x peak memory
+// usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>); \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>);
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>);
+
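+// The unlabeled registrations above make the custom kernel the default on CPU;
+// the "custom" and "eigen_tensor" labels keep both implementations addressable
+// (labeled kernels are typically selected per node via the "_kernel" attr),
+// which is mainly useful for benchmarking and testing.
+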
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
public:
@@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape;
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel {
return;
}
- // For the backprop of the filter, we need to also transpose the
- // out_backprop.
- // The shape of backprop is
- // [batch, out_z, out_y, out_x, out_depth]
- // And we need to change it to
- // [out_depth, out_x, out_y, out_z, batch]
- Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
- TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
- padded_out_cols, batch});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
+};
+
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
+ if (takes_shape_) {
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes.vec<int32>(), &filter_shape));
+ } else {
+ filter_shape = context->input(1).shape();
+ }
+
+ ConvBackpropDimensions dims;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // For the backprop of the filter, we need to transpose the input.
- // The shape of input is
- // [batch, in_z, in_y, in_x, in_depth]
- // And we need to change it to
- // [in_z, in_y, in_x, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
- TensorShape in_shuffle_shape(
- {input_size[0], input_size[1], input_size[2], batch, in_depth});
- Tensor in_shuffle;
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
+ Tensor* filter_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- in_shuffle_shape, &in_shuffle));
- // No need for reversing this time.
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
- no_reverse, in_shuffle.tensor<T, 5>());
- const Tensor& in_shuffle_cref = in_shuffle;
-
- // The output of the conv_3d would be
- // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
- // and we need to shuffle it back to
- // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
- // And we need to reverse the filter backprops.
- // So we need to allocate (sigh) yet another piece of memory to hold the
- // output.
- TensorShape filter_shuffle_shape(
- {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
- Tensor filter_shuffle;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::v(),
- filter_shuffle_shape, &filter_shuffle));
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
-
- // Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ if (input_shape.num_elements() == 0) {
+ filter_backprop->template flat<T>().setZero();
+ return;
+ }
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ // Shard 'batch' images (volumes) into 'shard_size' groups of images
+ // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
+ // dividing the L3 cache size ('target_working_set_size') by the matmul size
+ // of an individual image ('work_unit_size').
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // TODO(andydavis)
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ const int64 size_A = output_image_size * filter_total_size;
+
+ const int64 size_B = output_image_size * dims.out_depth;
+
+ const int64 size_C = filter_total_size * dims.out_depth;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+    // If the temporary allocation overhead is too large, fall back on the
+    // Eigen implementation, which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* input_data = input.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+ T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+ C.setZero();
+
+ // Initialize contraction dims (we need to transpose 'A' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 0;
+ contract_dims[0].second = 0;
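+    // Contracting index 0 of A [output_image_size * shard_limit,
+    // filter_total_size] with index 0 of B [output_image_size * shard_limit,
+    // out_depth] accumulates A^T * B, a [filter_total_size, out_depth] update
+    // to the filter gradient C, across all shards.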
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+ &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+ &bottom_pad_rows, &right_pad_cols, &input_offset,
+ &size_A](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ const T* input_data_shard = input_data + shard_id * input_offset;
+ T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data_shard, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ col_data_shard);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ size_A, shard);
+
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ dims.out_depth);
+
+ // Gradient with respect to filter.
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+ input_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
}
private:
@@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};
+// The custom backprop filter kernel is 30%-4x faster than the default Eigen
+// implementation when compiled with AVX2 (at the cost of ~2x-8x peak memory
+// usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropFilterOp<CPUDevice, T>); \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU_KERNEL);
+
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
+// WARNING: Eigen::half is not trivially copyable, so it cannot be used in the
+// custom backprop filter kernel, which relies on memcpy and memset in Im2col.
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
@@ -445,7 +1054,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
@@ -523,6 +1133,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -531,7 +1145,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1160,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
- dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
- stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+ if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+ dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+ dims.stride(1) == 1 && dims.stride(2) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
- const uint64 k = out_depth;
- const uint64 n = in_depth;
+ const uint64 m = dims.batch_size * dims.input_size(0) *
+ dims.input_size(1) * dims.input_size(2);
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
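+      // With a 1x1x1 filter, unit strides and dilations, and NHWC data, the
+      // input backprop reduces to one GEMM: the [m, k] out_backprop times the
+      // transpose of the flattened [in_depth, out_depth] filter yields the
+      // [m, n] input gradient.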
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -567,13 +1190,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = batch;
- const uint64 k = out_depth;
- const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -597,65 +1221,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
- const bool planes_odd = (padding_planes % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
- batch,
- in_depth,
- input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd,
+ dims.batch_size,
+ dims.in_depth,
+ dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd,
};
} else {
- compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
- input_size[2]};
+ compatible_input_shape = {dims.batch_size, dims.in_depth,
+ dims.input_size(0), dims.input_size(1),
+ dims.input_size(2)};
}
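+    // Example: with SAME padding, a total row padding of 3 would be split
+    // 1 (top) / 2 (bottom). cuDNN pads symmetrically, so the kernel instead
+    // pads padding_rows / 2 = 1 on each side and grows the input by one row
+    // (rows_odd) to stand in for the extra bottom padding; that extra row is
+    // stripped from in_backprop further below.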
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -664,20 +1282,23 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
Tensor transformed_filter;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &transformed_filter));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
- context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
- if (out_depth > 1) {
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
+ if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
@@ -713,14 +1334,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
@@ -799,10 +1420,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- {batch, in_depth, input_size[0],
- input_size[1], input_size[2]},
- &in_backprop_remove_padding));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {dims.batch_size, dims.in_depth, dims.input_size(0),
+ dims.input_size(1), dims.input_size(2)},
+ &in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1518,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1531,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
@@ -914,13 +1545,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
- dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
- strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+ if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+ dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+ dims.stride(1) == 1 && dims.stride(0) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = in_depth;
- const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
- const uint64 n = out_depth;
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.input_size(1) *
+ dims.input_size(2) * dims.input_size(0);
+ const uint64 n = dims.out_depth;
// The shape of output backprop is
// [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1584,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
- const uint64 k = batch;
- const uint64 n = out_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
@@ -979,30 +1613,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
- bool planes_odd = (padding_planes % 2 != 0);
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
Tensor compatible_input;
if (rows_odd || cols_odd || planes_odd) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, batch,
- {{input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd}},
- in_depth),
- &compatible_input));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format_, dims.batch_size,
+ {{dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd}},
+ dims.in_depth),
+ &compatible_input));
functor::PadInput<GPUDevice, T, int, 5>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1644,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
.set_spatial_dim(DimIndex::Y,
GetTensorDim(compatible_input, data_format_, '1'))
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1680,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
Tensor pre_transformed_filter_backprop;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &pre_transformed_filter_backprop));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &pre_transformed_filter_backprop));
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
- if (out_depth > 1) {
+ if (dims.out_depth > 1) {
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1706,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
Tensor transformed_input;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
- compatible_input.dim_size(2),
- compatible_input.dim_size(3)};
- if (in_depth > 1) {
+ TensorShape nchw_shape = {
+ dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2), compatible_input.dim_size(3)};
+ if (dims.in_depth > 1) {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
nchw_shape, &transformed_input));
@@ -1110,14 +1740,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
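
The Conv3DBackpropFilterOp hunks above replace the loose per-dimension locals (batch, in_depth, out_depth, filter_size[], strides[], dilations[], ...) with accessors on a single `dims` object whose declaration is not shown in this diff. Purely as an illustration of the accessor shape the new code relies on — the names and layout are assumptions, not the real ConvBackpropDimensions declaration — a minimal stand-in might be:

    // Illustrative stand-in only; the real dimension helper lives elsewhere
    // in the conv_grad_ops code and differs in detail.
    #include <cstdint>
    #include <vector>

    struct SpatialDimension {
      int64_t input_size, filter_size, output_size, stride, dilation;
    };

    struct BackpropDims {
      int64_t batch_size, in_depth, out_depth;
      std::vector<SpatialDimension> spatial;  // ordered: planes, rows, cols

      int64_t input_size(int i) const { return spatial[i].input_size; }
      int64_t filter_size(int i) const { return spatial[i].filter_size; }
      int64_t output_size(int i) const { return spatial[i].output_size; }
      int64_t stride(int i) const { return spatial[i].stride; }
      int64_t dilation(int i) const { return spatial[i].dilation; }
    };
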
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ef692418d6..717a9f40a9 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -680,9 +680,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
TensorShape({filter.dim_size(3), filter.dim_size(2),
filter.dim_size(0), filter.dim_size(1)}),
&transformed_filter));
-
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_output;
@@ -731,9 +731,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
- &algorithms));
+ OP_REQUIRES(
+ ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown("Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -823,7 +829,8 @@ namespace functor {
extern template struct MatMulConvFunctor<GPUDevice, T>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
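
The autotuning hunk above replaces a hard CHECK with OP_REQUIRES and errors::Unknown, so a failed cuDNN algorithm query surfaces as an op error instead of aborting the process. A minimal sketch of that same pattern, with a hypothetical QueryBackend() standing in for GetConvolveAlgorithms():

    // Sketch of CHECK -> OP_REQUIRES error propagation inside a kernel.
    // QueryBackend() is a hypothetical stand-in for GetConvolveAlgorithms().
    void Compute(OpKernelContext* ctx) override {
      std::vector<AlgorithmDesc> algorithms;
      OP_REQUIRES(ctx, QueryBackend(&algorithms),
                  errors::Unknown("Backend query failed; see warnings above."));
      // Execution only reaches this point if the query succeeded; otherwise
      // OP_REQUIRES records the error status and returns early.
    }
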
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index a1eed4e68c..83df4dce38 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -386,7 +386,8 @@ struct LaunchConvOp<GPUDevice, T> {
// filter: [x, y, z, in, out]
// t_filter: [out, in, x, y, z]
functor::TransformFilter<GPUDevice, T, int, 5>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
Tensor transformed_output;
@@ -434,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
- stream->parent()),
- &algorithms));
+ OP_REQUIRES(ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown(
+ "Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
+
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -514,7 +521,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index afc611f277..21d135decd 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -142,8 +142,12 @@ class ConvParameters {
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo(
se::StreamExecutor* stream_exec) const {
+ auto* dnn_support = stream_exec->AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
// Skip this check for cuDNN 7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
+ auto version = dnn_support->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a5fa48f85e..46167db3a2 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -170,51 +170,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
return tensor_index;
}
-// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
- Dimension<3> output_dims;
- output_dims[0] = input_dims[2];
- output_dims[1] = input_dims[1];
- output_dims[2] = input_dims[0];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
-
- Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
- Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[2];
- input_tensor_index[1] = output_tensor_index[1];
- input_tensor_index[2] = output_tensor_index[0];
-
- int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
-
- output[output_index] =
- maybe_conj<T, conjugate>::run(ldg(input + input_index));
- }
-}
-
-// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
+// A simple CUDA custom kernel that shuffles the dimensions of a 3D tensor
+// according to the permutation given as template parameters. The permutation
+// <sp0, sp1, sp2> maps input dimension 0 to output dimension sp0, dimension 1
+// to sp1, and dimension 2 to sp2. For example, the permutation <2, 0, 1>
+// populates the output so that input[x][y][z] is equal to output[y][z][x].
+//
+// Requires that nthreads is equal to the total number of elements in the input
+// tensor.
+template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
+__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
+ Dimension<3> input_dims, T* output) {
Dimension<3> output_dims;
- output_dims[0] = input_dims[0];
- output_dims[1] = input_dims[2];
- output_dims[2] = input_dims[1];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
+ output_dims[sp0] = input_dims[0];
+ output_dims[sp1] = input_dims[1];
+ output_dims[sp2] = input_dims[2];
+
+  // Iterate over the output rather than over the input for better
+  // performance: iterating over the output generates sequential writes and
+  // random reads, which perform better than the sequential reads and random
+  // writes produced by iterating over the input.
+ CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[0];
- input_tensor_index[1] = output_tensor_index[2];
- input_tensor_index[2] = output_tensor_index[1];
+ input_tensor_index[0] = output_tensor_index[sp0];
+ input_tensor_index[1] = output_tensor_index[sp1];
+ input_tensor_index[2] = output_tensor_index[sp2];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
@@ -439,7 +421,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, int>::ConstTensor in,
typename TTypes<T, NDIMS, int>::Tensor out) {
Dimension<3> combined_dims;
@@ -450,13 +432,18 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
combined_dims[1] = in.dimension(NDIMS - 2); // input filters
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported output layout: " << ToString(dst_filter_format);
+
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
-// Converts Cudnn filter format back to TensorFlow filter format.
+// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
@@ -470,7 +457,7 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
combined_dims[2] *= in.dimension(i);
}
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
@@ -937,7 +924,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
- SwapDimension1And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, input, input_dims, output);
}
@@ -969,7 +956,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
- SwapDimension0And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in, input_dims, out);
}
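
To make the permutation semantics of ShuffleInTensor3Simple concrete, here is a small host-side C++ reference (no CUDA) that applies the same index math and checks the <2, 0, 1> example from the comment above. It is an illustration only, not code from this commit:

    #include <array>
    #include <cassert>
    #include <vector>

    // Host-side reference for the <sp0, sp1, sp2> shuffle: input dimension 0
    // goes to output dimension sp0, 1 goes to sp1, and 2 goes to sp2.
    template <int sp0, int sp1, int sp2>
    std::vector<int> Shuffle(const std::vector<int>& in, std::array<int, 3> dims) {
      std::array<int, 3> out_dims{};
      out_dims[sp0] = dims[0];
      out_dims[sp1] = dims[1];
      out_dims[sp2] = dims[2];
      std::vector<int> out(in.size());
      for (int x = 0; x < dims[0]; ++x) {
        for (int y = 0; y < dims[1]; ++y) {
          for (int z = 0; z < dims[2]; ++z) {
            std::array<int, 3> out_idx{};
            out_idx[sp0] = x;
            out_idx[sp1] = y;
            out_idx[sp2] = z;
            int flat_in = (x * dims[1] + y) * dims[2] + z;
            int flat_out =
                (out_idx[0] * out_dims[1] + out_idx[1]) * out_dims[2] + out_idx[2];
            out[flat_out] = in[flat_in];
          }
        }
      }
      return out;
    }

    int main() {
      std::array<int, 3> dims = {2, 3, 4};
      std::vector<int> in(2 * 3 * 4);
      for (int i = 0; i < static_cast<int>(in.size()); ++i) in[i] = i;
      // <2, 0, 1>: output shape is (3, 4, 2) and input[x][y][z] == output[y][z][x].
      std::vector<int> out = Shuffle<2, 0, 1>(in, dims);
      int in_flat = (1 * 3 + 2) * 4 + 3;   // input[1][2][3]
      int out_flat = (2 * 4 + 3) * 2 + 1;  // output[2][3][1] in a (3, 4, 2) tensor
      assert(out[out_flat] == in[in_flat]);
      return 0;
    }
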
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 8d867455e7..b3c359010d 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -51,6 +51,7 @@ cc_library(
hdrs = ["captured_function.h"],
deps = [
":dataset",
+ ":single_threaded_executor",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -61,6 +62,42 @@ cc_library(
)
cc_library(
+ name = "single_threaded_executor",
+ srcs = ["single_threaded_executor.cc"],
+ hdrs = ["single_threaded_executor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "single_threaded_executor_test",
+ srcs = ["single_threaded_executor_test.cc"],
+ deps = [
+ ":single_threaded_executor",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:array",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:function_ops",
+ "//tensorflow/core/kernels:math",
+ "//tensorflow/core/kernels:random_ops",
+ "//tensorflow/core/kernels:state",
+ ],
+)
+
+cc_library(
name = "window_dataset",
srcs = ["window_dataset.cc"],
hdrs = ["window_dataset.h"],
@@ -481,8 +518,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -505,8 +541,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -640,6 +675,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "model_dataset_op",
+ srcs = ["model_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
srcs = ["dataset_ops.cc"],
deps = [
@@ -673,6 +721,7 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":map_defun_op",
+ ":model_dataset_op",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index f9b5353724..d1db1d7bec 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
@@ -241,5 +242,5 @@ REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
BatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 6ca0bcd37d..34c6c86538 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level description of
@@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
- new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+ new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -553,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new MemoryIterator(
- {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+ {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
@@ -891,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
CacheDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index abdf6ee4e8..8a5d30a27c 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,33 +17,101 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
+namespace data {
+
+namespace {
+
+// Simplistic implementation of the `StepStatsCollectorInterface` that only
+// cares about collecting the CPU time needed to execute a captured function.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+ void IncrementProcessingTime(int64 delta) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+ return new SimpleNodeExecStats(this);
+ }
+
+ string ReportAllocsOnResourceExhausted(const string& err) override {
+ return "";
+ }
+
+ int64 processing_time() {
+ tf_shared_lock l(mu_);
+ return processing_time_;
+ }
+
+ private:
+ class SimpleNodeExecStats : public NodeExecStatsInterface {
+ public:
+ explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+ : step_stats_collector_(step_stats_collector) {}
+
+ void Done(const string& device) override {
+ step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
+ start_time_ns_);
+ delete this;
+ }
+
+ void RecordExecutorStarted() override {
+ start_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void RecordComputeStarted() override {}
+
+ void RecordComputeEnded() override {}
+
+ void RecordExecutorEnded() override {
+ end_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void SetMemory(OpKernelContext* ctx) override {}
+
+ void SetOutput(int slot, const Tensor* tensor) override {}
+
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+
+ void SetScheduled(int64 nanos) override {}
+
+ private:
+ int64 start_time_ns_ = 0;
+ int64 end_time_ns_ = 0;
+ SimpleStepStatsCollector* step_stats_collector_; // Not owned.
+ };
+
+ mutex mu_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+} // namespace
/* static */
Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs)));
- return Status::OK();
+ return CapturedFunction::Create(func, ctx, argument, true, out_function);
}
-/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function) {
- OpInputList argument_inputs;
- TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
- std::vector<Tensor> arguments_t;
- arguments_t.reserve(argument_inputs.size());
- for (const Tensor& t : argument_inputs) {
- arguments_t.push_back(t);
- }
- return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+ OpInputList inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+ std::vector<Tensor> arguments(inputs.begin(), inputs.end());
+ *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments),
+ use_inter_op_parallelism));
+ return Status::OK();
}
CapturedFunction::~CapturedFunction() {
@@ -272,6 +340,9 @@ Status CapturedFunction::Instantiate(IteratorContext* ctx) {
inst_opts.overlay_lib = ctx->function_library().get();
inst_opts.state_handle = std::to_string(random::New64());
inst_opts.create_kernels_eagerly = true;
+ if (!use_inter_op_parallelism_) {
+ inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
+ }
Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
inst_opts, &f_handle_));
TF_RETURN_IF_ERROR(s);
@@ -345,7 +416,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
void CapturedFunction::RunAsync(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix) {
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
@@ -355,17 +427,17 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
done(s);
return;
}
- auto frame =
- new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
+ std::shared_ptr<OwnedArgsCallFrame> frame(
+ new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- auto step_container = new ScopedStepContainer(
+ std::shared_ptr<ScopedStepContainer> step_container(new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
- });
- f_opts.step_container = step_container;
+ }));
+ f_opts.step_container = step_container.get();
f_opts.runner = ctx->runner();
if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
f_opts.create_rendezvous = true;
@@ -376,32 +448,55 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
- f_opts.cancellation_manager = c_mgr;
-
- tf_shared_lock l(mu_);
- ctx->lib()->Run(f_opts, handle, frame,
- std::bind(
- [rets, step_container, c_mgr, frame](
- FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete step_container;
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
+ std::shared_ptr<CancellationManager> c_mgr(new CancellationManager);
+ f_opts.cancellation_manager = c_mgr.get();
+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
+ std::shared_ptr<model::Node> node;
+ if (ctx->model()) {
+ node = ctx->model()->LookupNode(prefix);
+ if (node) {
+ stats_collector = MakeUnique<SimpleStepStatsCollector>();
+ }
+ }
+ f_opts.stats_collector = stats_collector.get();
+
+ OwnedArgsCallFrame* raw_frame = frame.get();
+ auto callback = std::bind(
+ [rets](const std::shared_ptr<CancellationManager>& c_mgr,
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<OwnedArgsCallFrame>& frame,
+ const std::shared_ptr<model::Node>& node,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+ const std::shared_ptr<ScopedStepContainer>& step_container,
+ // Begin unbound arguments.
+ Status s) {
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ if (node) {
+ node->add_processing_time(stats_collector->processing_time());
+ node->start_work();
+ }
+ done(s);
+ if (node) {
+ node->stop_work();
+ }
+ },
+ std::move(c_mgr), std::move(done), std::move(frame), std::move(node),
+ std::move(stats_collector), std::move(step_container),
+ std::placeholders::_1);
+
+ ctx->lib()->Run(f_opts, handle, raw_frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs)
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism)
: func_(func),
lib_(nullptr),
f_handle_(kInvalidHandle),
- captured_inputs_(std::move(captured_inputs)) {}
+ captured_inputs_(std::move(captured_inputs)),
+ use_inter_op_parallelism_(use_inter_op_parallelism) {}
+} // namespace data
} // namespace tensorflow
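
The SimpleStepStatsCollector added above accumulates, under a mutex, the executor wall time reported by each node when it finishes, and RunAsync() then forwards the total to the model node for the iterator prefix. A stripped-down standalone analogue of that accumulate-on-completion idea (plain C++, not the TensorFlow interfaces):

    #include <chrono>
    #include <cstdint>
    #include <mutex>

    // Thread-safe accumulator: each finished unit of work adds its elapsed time.
    class ProcessingTimeCollector {
     public:
      void Add(int64_t delta_ns) {
        std::lock_guard<std::mutex> l(mu_);
        total_ns_ += delta_ns;
      }
      int64_t total_ns() {
        std::lock_guard<std::mutex> l(mu_);
        return total_ns_;
      }

     private:
      std::mutex mu_;
      int64_t total_ns_ = 0;
    };

    // RAII analogue of SimpleNodeExecStats: record the start on construction
    // and report the delta to the shared collector on destruction ("Done").
    class ScopedWorkTimer {
     public:
      explicit ScopedWorkTimer(ProcessingTimeCollector* c)
          : collector_(c), start_(std::chrono::steady_clock::now()) {}
      ~ScopedWorkTimer() {
        auto delta = std::chrono::steady_clock::now() - start_;
        collector_->Add(
            std::chrono::duration_cast<std::chrono::nanoseconds>(delta).count());
      }

     private:
      ProcessingTimeCollector* collector_;  // Not owned.
      std::chrono::steady_clock::time_point start_;
    };
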
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index c95f2b1c01..a10376bf97 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -32,6 +32,8 @@ class Device;
class OpKernelContext;
class ResourceMgr;
+namespace data {
+
// A `CapturedFunction` encapsulates a TensorFlow function and all of
// the runtime support required to execute it.
//
@@ -40,18 +42,19 @@ class ResourceMgr;
// context.
class CapturedFunction {
public:
- // Creates a new instance from a list of named attributes and captured inputs.
- //
- // NOTE(mrry): The `captured_inputs` are passed by value. For
- // efficiency, you are recommended to move this argument into the call.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
std::unique_ptr<CapturedFunction>* out_function);
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument.
+ //
+ // If `use_inter_op_parallelism` is false, the runtime may use an executor
+ // that is optimized for small functions.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
- const string& argument,
+ const string& argument, bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
~CapturedFunction();
@@ -93,7 +96,8 @@ class CapturedFunction {
// in order to be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done);
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix);
// Returns the named list of function arguments.
const NameAttrList& func() { return func_; }
@@ -114,7 +118,8 @@ class CapturedFunction {
private:
CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs);
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism);
Status GetHandle(IteratorContext* ctx,
FunctionLibraryRuntime::Handle* out_handle);
@@ -126,10 +131,17 @@ class CapturedFunction {
const std::vector<Tensor> captured_inputs_;
DataTypeSlice ret_types_;
std::function<void(std::function<void()>)> captured_runner_ = nullptr;
+ const bool use_inter_op_parallelism_;
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
+} // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::CapturedFunction;
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
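
With the header change above, a dataset kernel opts into the small-function executor by passing use_inter_op_parallelism = false to the four-argument Create() overload, mirroring what MapDatasetOp does later in this diff. A sketch of that call site (kernel members such as func_ are assumed to exist as usual):

    // Sketch of a MakeDataset() body using the new overload; mirrors the
    // MapDatasetOp change further down in this diff.
    std::unique_ptr<CapturedFunction> captured_func;
    OP_REQUIRES_OK(ctx, CapturedFunction::Create(
                            func_, ctx, "other_arguments",
                            /*use_inter_op_parallelism=*/false, &captured_func));
    // When use_inter_op_parallelism is false, Instantiate() requests the
    // "SINGLE_THREADED_EXECUTOR" executor type for the captured function.
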
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index c361a9adcb..a04f150e71 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
ConcatenateDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
index c71d027f23..bd1ccd5b5d 100644
--- a/tensorflow/core/kernels/data/dataset_ops.cc
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -48,4 +49,5 @@ class DatasetToGraphOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
DatasetToGraphOp);
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index d85ef1cbab..e7ac368ae3 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -17,8 +17,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
namespace tensorflow {
-
-namespace dataset {
+namespace data {
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
@@ -45,6 +44,5 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
-} // namespace dataset
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6c4191c2be..234856ea39 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -20,16 +20,14 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
-namespace dataset {
+namespace data {
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
-} // namespace dataset
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 9770bc025d..237511a07d 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -301,5 +301,5 @@ REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
DenseToSparseBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
index ce577397c5..a7e3a56727 100644
--- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -166,5 +166,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU),
FilterByLastComponentDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index bbce001eaf..19c35f94a6 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -37,14 +37,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
FunctionLibraryRuntime::Handle pred_handle;
OP_REQUIRES_OK(ctx,
ctx->function_library()->Instantiate(
@@ -61,9 +53,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Node* ret_node = pred_body->ret_nodes[0];
Node* ret_input_node;
OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
if (ret_input_node->def().op() == "_Arg") {
int32 index = -1;
@@ -280,5 +273,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU),
FilterDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index b1eb2fd849..2fada22a21 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
-
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
}
@@ -245,7 +236,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
private:
Status BuildCurrentElementIteratorLocked(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return dataset::MakeIteratorFromInputElement(
+ return MakeIteratorFromInputElement(
ctx, captured_func_inputs_, element_index_++,
dataset()->captured_func_.get(), prefix(),
&current_element_iterator_);
@@ -285,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU),
FlatMapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ccee690d7e..71a36314a0 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -144,54 +145,31 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
std::move(finalize_func), output_types_, output_shapes_);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
GeneratorDatasetOp);
+} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h
index 8407543136..d23ed97ec3 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.h
+++ b/tensorflow/core/kernels/data/generator_dataset_op.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
+namespace data {
class GeneratorDatasetOp : public DatasetOpKernel {
public:
@@ -36,5 +37,6 @@ class GeneratorDatasetOp : public DatasetOpKernel {
NameAttrList finalize_func_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index 130f04da3e..d6ee42a7c6 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -433,4 +434,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
GroupByReducerDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 46a3185b49..8b417bb1c2 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -41,50 +42,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- // Get captured inputs for the key, reduce, and window_size functions.
- OpInputList key_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments",
- &key_func_other_argument_inputs));
- std::vector<Tensor> key_func_other_arguments;
- key_func_other_arguments.reserve(key_func_other_argument_inputs.size());
- for (const Tensor& t : key_func_other_argument_inputs) {
- key_func_other_arguments.push_back(t);
- }
- OpInputList reduce_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments",
- &reduce_func_other_argument_inputs));
- std::vector<Tensor> reduce_func_other_arguments;
- reduce_func_other_arguments.reserve(
- reduce_func_other_argument_inputs.size());
- for (const Tensor& t : reduce_func_other_argument_inputs) {
- reduce_func_other_arguments.push_back(t);
- }
- OpInputList window_size_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx,
- ctx->input_list("window_size_func_other_arguments",
- &window_size_func_other_argument_inputs));
- std::vector<Tensor> window_size_func_other_arguments;
- window_size_func_other_arguments.reserve(
- window_size_func_other_argument_inputs.size());
- for (const Tensor& t : window_size_func_other_argument_inputs) {
- window_size_func_other_arguments.push_back(t);
- }
- // TODO(mrry): Refactor CapturedFunction to share the runtime
- // state between multiple functions?
std::unique_ptr<CapturedFunction> captured_key_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- key_func_, std::move(key_func_other_arguments),
- &captured_key_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
std::unique_ptr<CapturedFunction> captured_reduce_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(reduce_func_,
- std::move(reduce_func_other_arguments),
- &captured_reduce_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
std::unique_ptr<CapturedFunction> captured_window_size_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- window_size_func_, std::move(window_size_func_other_arguments),
- &captured_window_size_func));
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(window_size_func_, ctx,
+ "window_size_func_other_arguments",
+ &captured_window_size_func));
*output = new Dataset(
ctx, input, key_func_, reduce_func_, window_size_func_,
@@ -549,4 +519,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
GroupByWindowDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 716e040277..0aa802b874 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
const Tensor* cycle_length_t;
OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()),
@@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("block_length must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
@@ -201,7 +193,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(input_impl_->GetNext(
ctx, &args_list_[cycle_index_], &end_of_input_));
if (!end_of_input_) {
- TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[cycle_index_], cycle_index_,
dataset()->captured_func_.get(), prefix(),
&current_elements_[cycle_index_]));
@@ -288,7 +280,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
- TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
prefix(), &current_elements_[idx]));
TF_RETURN_IF_ERROR(
@@ -330,5 +322,5 @@ REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU),
InterleaveDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 4e9b280968..30c6585ba2 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -236,6 +236,8 @@ class IteratorResource : public ResourceBase {
const std::vector<PartialTensorShape> output_shapes_;
};
+namespace {
+
// Helper class for reading data from a VariantTensorData object.
class VariantTensorDataReader : public IteratorStateReader {
public:
@@ -401,12 +403,12 @@ class IteratorStateVariant {
}
string TypeName() const { return kIteratorVariantTypeName; }
void Encode(VariantTensorData* data) const { *data = *data_; }
- bool Decode(const VariantTensorData& data) {
+ bool Decode(VariantTensorData data) {
if (data.type_name() != TypeName()) {
return false;
}
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
- *tensor_data = data;
+ std::swap(*tensor_data, data);
std::unique_ptr<VariantTensorDataReader> reader(
new VariantTensorDataReader(tensor_data.get()));
status_ = reader->status();
@@ -443,6 +445,8 @@ class IteratorStateVariant {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
kIteratorVariantTypeName);
+} // namespace
+
// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
@@ -622,6 +626,8 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
+namespace {
+
class ToSingleElementOp : public AsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx)
@@ -887,6 +893,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
const int graph_def_version_;
};
+} // namespace
+
void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
IteratorResource* iterator;
OP_REQUIRES_OK_ASYNC(
@@ -957,6 +965,8 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
}
}
+namespace {
+
class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
public:
explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
@@ -1037,6 +1047,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+} // namespace
+
void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
@@ -1108,6 +1120,8 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}
+namespace {
+
class SerializeIteratorOp : public OpKernel {
public:
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1202,4 +1216,7 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
+} // namespace
+
+} // namespace data
} // namespace tensorflow
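
The Decode() change above switches the parameter from `const VariantTensorData&` to pass-by-value and then swaps the payload into place, so a caller that passes an rvalue avoids a second deep copy. A minimal standalone sketch of that sink-argument pattern:

    #include <string>
    #include <utility>

    // Sink-argument pattern: take the payload by value, then swap it into the
    // member. Callers passing an rvalue pay one move; lvalue callers pay
    // exactly one copy.
    class Holder {
     public:
      bool Decode(std::string data) {
        if (data.empty()) return false;
        std::swap(data_, data);  // steal the buffer instead of copying it again
        return true;
      }

     private:
      std::string data_;
    };

    // Usage (ReadSerializedState() is a hypothetical helper):
    //   Holder h;
    //   std::string s = ReadSerializedState();
    //   h.Decode(std::move(s));   // no second deep copy of the buffer
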
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 723564286c..8a2b2639a7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
+namespace data {
class IteratorResource;
@@ -142,6 +143,7 @@ class IteratorFromStringHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 8b0c9ad6b2..83896219a3 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -26,10 +26,11 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -39,7 +40,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +49,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 batch_size;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
OP_REQUIRES(
@@ -77,7 +69,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
@@ -92,8 +85,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
@@ -190,7 +183,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -204,6 +198,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ std::function<void(int64)> set_fn = [this](int64 value) {
+ {
+ mutex_lock l(mu_);
+ num_parallel_calls_ = value;
+ }
+ VLOG(2) << "setting parallelism knob to " << value;
+ cond_var_.notify_all();
+ };
+ AddTunableParameter(
+ ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
+ port::NumSchedulableCPUs() /* max */, std::move(set_fn));
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -218,7 +230,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
@@ -365,7 +379,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
- });
+ },
+ prefix());
},
ctx, std::move(input_element)));
}
@@ -423,7 +438,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
+ return (num_parallel_calls_ + dataset()->batch_size_ - 1) /
dataset()->batch_size_;
}
@@ -475,23 +490,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
+ StartWork(ctx.get());
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
+ {
+ tf_shared_lock l(mu_);
+ new_calls.reserve(num_parallel_calls_);
+ }
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
+ (num_calls_ >= num_parallel_calls_ ||
batch_results_.size() > MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ == 0))) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
+ while (num_calls_ < num_parallel_calls_ &&
(batch_results_.size() < MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ != 0))) {
@@ -638,6 +661,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
@@ -661,7 +686,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
@@ -675,5 +699,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
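
In the Initialize() hunk above, `parallelism` is registered either as a constant or, when num_parallel_calls is kAutoTune, as a tunable parameter whose setter rewrites num_parallel_calls_ under the iterator lock and notifies cond_var_. A simplified standalone analogue of that knob-with-callback shape (illustrative only; not the real AddTunableParameter API):

    #include <condition_variable>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <mutex>

    // Simplified analogue of the tunable-parameter registration: the tuner owns
    // a setter callback; the iterator-side state is updated under its own mutex
    // and waiters are notified, mirroring set_fn in the diff.
    struct TunableKnob {
      int64_t value;
      int64_t min;
      int64_t max;
      std::function<void(int64_t)> set_fn;
    };

    class ToyIterator {
     public:
      TunableKnob MakeParallelismKnob(int64_t max_parallelism) {
        return TunableKnob{/*value=*/1, /*min=*/1, /*max=*/max_parallelism,
                           [this](int64_t v) {
                             {
                               std::lock_guard<std::mutex> l(mu_);
                               num_parallel_calls_ = v;
                             }
                             cond_var_.notify_all();  // wake threads waiting on capacity
                           }};
      }

     private:
      std::mutex mu_;
      std::condition_variable cond_var_;
      int64_t num_parallel_calls_ = 1;
    };

    int main() {
      ToyIterator it;
      TunableKnob knob = it.MakeParallelismKnob(/*max_parallelism=*/8);
      knob.set_fn(4);  // a tuner would call this while searching [min, max]
      std::cout << "knob range: [" << knob.min << ", " << knob.max << "]\n";
      return 0;
    }
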
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 7f8182d917..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,27 +28,20 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit MapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
@@ -183,14 +176,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList func_;
+ bool use_inter_op_parallelism_;
};
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 607d0ca028..6657f2b2b3 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -18,18 +18,20 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
+namespace data {
namespace {
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
bool always_collect_stats) {
opts->step_id = ctx->step_id();
opts->rendezvous = ctx->rendezvous();
- opts->cancellation_manager = ctx->cancellation_manager();
if (always_collect_stats) {
opts->stats_collector = ctx->stats_collector();
}
@@ -60,103 +62,186 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size = ctx->input(0).dim_size(0);
- // Inputs
- auto* args = new std::vector<Tensor>;
- auto* arg_shapes = new std::vector<TensorShape>;
- arg_shapes->reserve(ctx->num_inputs());
- args->reserve(ctx->num_inputs());
-
+ Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
+ // Validates inputs and gets the size of their leading dimension.
+ *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args->push_back(ctx->input(i));
- arg_shapes->push_back(ctx->input(i).shape());
- arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- OP_REQUIRES_ASYNC(
- ctx, batch_size == ctx->input(i).dim_size(0),
- errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size,
- "."),
- done);
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != *batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
}
+ return Status::OK();
+ }
- // Outputs
- auto* output = new OpOutputList;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ ComputeOptions* compute_opts = nullptr;
- for (size_t i = 0; i < output_types().size(); ++i) {
- Tensor* out = nullptr;
- TensorShape output_shape = output_shapes_.at(i);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
- }
+ OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
- SetRunOptions(ctx, &opts_, false);
+ Status s = SetupOutputs(ctx, compute_opts);
+ if (!s.ok()) delete compute_opts;
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
+
+ FunctionLibraryRuntime::Options opts;
+ SetRunOptions(ctx, &opts, false);
// Run loop
StatusCallback callback = std::bind(
- [](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes, OpOutputList* output,
+ [](OpKernelContext* ctx, ComputeOptions* compute_opts,
DoneCallback& done, const Status& status) {
- delete args;
- delete arg_shapes;
- delete output;
+ delete compute_opts;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+ ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
- for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
- // Start from i = 1 because refcounted is initialized with refcount = 1
+ CancellationManager* parent_mgr = ctx->cancellation_manager();
+
+ for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+ // We use a different cancellation manager each time the function is run
+ // to avoid the race condition between a function run error and other
+ // functions being cancelled as a result.
+ CancellationManager* c_mgr = new CancellationManager;
+ CancellationToken token = parent_mgr->get_cancellation_token();
+ const bool success = parent_mgr->RegisterCallback(
+ token, [c_mgr]() { c_mgr->StartCancel(); });
+
+ opts.cancellation_manager = c_mgr;
+ if (!success) {
+ delete c_mgr;
+ refcounted->UpdateStatus(errors::Cancelled(
+ "MapDefunOp functions cancelled because parent graph cancelled"));
+ break;
+ }
+
+ auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
refcounted->Ref();
+ ctx->function_library()->Run(opts, func_handle_, call_frame,
+ [call_frame, refcounted, c_mgr, parent_mgr,
+ token](const Status& func_status) {
+ parent_mgr->DeregisterCallback(token);
+ delete c_mgr;
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
}
- for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame =
- new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
- ctx->function_library()->Run(
- opts_, func_handle_, call_frame,
- [call_frame, refcounted](const Status& func_status) {
- delete call_frame;
- refcounted->UpdateStatus(func_status);
- refcounted->Unref();
- });
- }
+
+ // Unref 1 because refcounted is initialized with refcount = 1
+ refcounted->Unref();
}
private:
FunctionLibraryRuntime::Handle func_handle_;
- FunctionLibraryRuntime::Options opts_;
- std::vector<TensorShape> output_shapes_;
+ std::vector<PartialTensorShape> output_shapes_;
+
+ struct ComputeOptions {
+ // These vary per MapDefunOp::ComputeAsync call, but must persist until
+ // all calls to the function are complete. This struct also encapsulates
+ // all the components that need to be passed to each MapFunctionCallFrame.
+
+ const std::vector<Tensor> args;
+ const std::vector<TensorShape> arg_shapes;
+ const int64 batch_size;
+
+ // Output of a compute call
+ std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+ OpOutputList output GUARDED_BY(mu);
+ mutex mu;
+
+ // Create a copy of output_shapes because every `Compute` may expect a
+ // different output shape.
+ ComputeOptions(std::vector<Tensor> args,
+ std::vector<TensorShape> arg_shapes, int64 batch_size,
+ const std::vector<PartialTensorShape>& output_shapes_attr)
+ : args(std::move(args)),
+ arg_shapes(std::move(arg_shapes)),
+ batch_size(batch_size),
+ output_shapes(output_shapes_attr) {}
+ };
+
+ // Get inputs to Compute and check that they are valid.
+ Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+ int64 batch_size =
+ ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+
+ std::vector<Tensor> args;
+ std::vector<TensorShape> arg_shapes;
+ args.reserve(ctx->num_inputs());
+ arg_shapes.reserve(ctx->num_inputs());
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ arg_shapes.push_back(ctx->input(i).shape());
+ arg_shapes.at(i).RemoveDim(0);
+ }
+
+ *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
+ batch_size, output_shapes_);
+ return Status::OK();
+ }
+
+ Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+ mutex_lock l(opts->mu);
+ TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, opts->batch_size);
+ TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+ }
+ }
+ return Status::OK();
+ }
class MapFunctionCallFrame : public CallFrameInterface {
public:
- MapFunctionCallFrame(const std::vector<Tensor>& args,
- const std::vector<TensorShape>& arg_shapes,
- OpOutputList* output, OpKernel* kernel, size_t iter)
- : args_(args),
- arg_shapes_(arg_shapes),
- output_(output),
- kernel_(kernel),
- iter_(iter) {}
+ MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+ size_t iter)
+ : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
- size_t num_args() const override { return args_.size(); }
+ size_t num_args() const override { return compute_opts_->args.size(); }
+
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= args_.size()) {
+ if (index < 0 || index >= compute_opts_->args.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
- bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
- arg_shapes_.at(index));
+ bool result =
+ val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
@@ -179,18 +264,39 @@ class MapDefunOp : public AsyncOpKernel {
"output: ",
index);
}
- return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
+ { // Locking scope
+ mutex_lock l(compute_opts_->mu);
+ if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+ val.shape())) {
+ return errors::InvalidArgument(
+ "Mismatch in function retval shape, ", val.shape(),
+ ", and expected output shape, ",
+ compute_opts_->output_shapes.at(index).DebugString(), ".");
+ }
+ if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
+ // Given val, we have new information about the output shape at
+ // this index. Store the shape and allocate the output accordingly.
+ compute_opts_->output_shapes.at(index) = val.shape();
+
+ Tensor* out = nullptr;
+ TensorShape actual_shape = val.shape();
+ actual_shape.InsertDim(0, compute_opts_->batch_size);
+ TF_RETURN_IF_ERROR(
+ compute_opts_->output.allocate(index, actual_shape, &out));
+ }
+ return batch_util::CopyElementToSlice(
+ val, (compute_opts_->output)[index], iter_);
+ }
}
private:
- const std::vector<Tensor>& args_;
- const std::vector<TensorShape>& arg_shapes_;
- OpOutputList* output_;
+ ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
};
-}; // namespace
+};
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
} // namespace
+} // namespace data
} // namespace tensorflow
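The ComputeAsync rewrite above gives every function invocation its own CancellationManager, chained to the parent manager so that cancelling the parent graph cancels all child calls, while an error in one call cannot cancel its siblings. Below is a minimal sketch of that chaining in isolation, using the same CancellationManager calls that appear in the diff; RunWithChildCancellation and run_one_call are hypothetical stand-ins for the kernel's ctx->function_library()->Run invocation.

// Sketch only: per-call cancellation chaining as used in MapDefunOp above.
#include <functional>
#include "tensorflow/core/framework/cancellation.h"

namespace tensorflow {

void RunWithChildCancellation(
    CancellationManager* parent_mgr,
    const std::function<void(CancellationManager*, std::function<void()>)>&
        run_one_call) {
  auto* c_mgr = new CancellationManager;
  CancellationToken token = parent_mgr->get_cancellation_token();
  const bool registered =
      parent_mgr->RegisterCallback(token, [c_mgr]() { c_mgr->StartCancel(); });
  if (!registered) {
    // The parent was already cancelled; do not start the call at all.
    delete c_mgr;
    return;
  }
  // The second argument is the cleanup the call must run when it completes,
  // mirroring the completion lambda passed to Run() in the diff: deregister
  // from the parent manager, then delete the per-call manager.
  run_one_call(c_mgr, [parent_mgr, token, c_mgr]() {
    parent_mgr->DeregisterCallback(token);
    delete c_mgr;
  });
}

}  // namespace tensorflow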
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
new file mode 100644
index 0000000000..63025d3371
--- /dev/null
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -0,0 +1,146 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis;
+
+class ModelDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ModelDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Model")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override { return "ModelDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ model_(std::make_shared<model::Model>()) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
+ &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ int64 now = ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ if (last_optimization_ms_ + optimization_period_ms_ < now) {
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization until
+ // a threshold is reached.
+ if (optimization_period_ms_ < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms_ << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms_ <<= 1;
+ } else {
+ optimization_period_ms_ = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms_ =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return input_impl_->GetNext(&ctx_with_model, out_tensors,
+ end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ IteratorContext::Params CreateParams(IteratorContext* ctx) {
+ IteratorContext::Params params = ctx->params();
+ params.model = model_;
+ return params;
+ }
+
+ private:
+ mutex mu_;
+ std::shared_ptr<model::Model> model_;
+ int64 last_optimization_ms_ GUARDED_BY(mu_) = 0;
+ int64 optimization_period_ms_ GUARDED_BY(mu_) = 10;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
+ ModelDatasetOp);
+} // namespace
+} // namespace data
+} // namespace tensorflow
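The optimization loop in GetNextInternal doubles optimization_period_ms_ after every optimization pass and clamps it at kOptimizationPeriodThresholdMs. A standalone sketch of just that schedule follows; the 10 ms starting value comes from the member initializer above, while the 60-second threshold used in main() is illustrative.

#include <cstdint>
#include <cstdio>

// Mirrors the period update in ModelDatasetOp::Dataset::Iterator above:
// double the period until doubling would overshoot the threshold, then hold.
int64_t NextOptimizationPeriodMs(int64_t period_ms, int64_t threshold_ms) {
  if (period_ms >= threshold_ms) return threshold_ms;
  if ((period_ms << 1) < threshold_ms) return period_ms << 1;
  return threshold_ms;
}

int main() {
  int64_t period = 10;                  // initial value of optimization_period_ms_
  const int64_t threshold = 60 * 1000;  // illustrative: 60 seconds in milliseconds
  for (int i = 0; i < 16; ++i) {
    std::printf("%lld\n", static_cast<long long>(period));
    period = NextOptimizationPeriodMs(period, threshold);
  }
  return 0;
}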
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 831e7252da..d5b725eac9 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -92,8 +93,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
+ std::vector<std::pair<string, Tensor>> input_list;
params.allow_stateful_functions = true;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ params.input_list = &input_list;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, input_, &input_node));
@@ -118,7 +121,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(
- graph_runner.Run(&graph, lib_, {}, {output_node}, &outputs));
+ graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
optimized_input_->Ref();
@@ -268,4 +271,5 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
OptimizeDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index cfac45dbc7..346e4ceebd 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
namespace tensorflow {
+namespace data {
namespace {
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
@@ -107,11 +108,8 @@ class OptionalFromValueOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
OpInputList components_input;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
- std::vector<Tensor> components;
- components.reserve(components_input.size());
- for (const Tensor& component_t : components_input) {
- components.push_back(component_t);
- }
+ std::vector<Tensor> components(components_input.begin(),
+ components_input.end());
OP_REQUIRES_OK(
ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}
@@ -230,10 +228,9 @@ static Status OptionalDeviceCopy(
return Status::OK();
}
-#define REGISTER_OPTIONAL_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
- OptionalDeviceCopy)
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, OptionalDeviceCopy)
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -267,4 +264,5 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
return Status::OK();
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
index 6f25567678..2cbf2933f5 100644
--- a/tensorflow/core/kernels/data/optional_ops.h
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_tensor_data.h"
namespace tensorflow {
+namespace data {
// Stores a DT_VARIANT value representing an Optional with the given value
// in the `output_index`^th output of the given kernel execution context.
@@ -31,6 +32,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index be45eac46e..7b01c3b4e0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
@@ -382,5 +383,5 @@ REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
PaddedBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index f6b3fd97e3..9cd46bf5dd 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>
+#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -21,11 +22,12 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -34,8 +36,7 @@ namespace {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -43,14 +44,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -82,8 +75,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, interleave_func_, std::move(captured_func),
@@ -125,6 +118,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
@@ -250,6 +244,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -349,11 +344,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
+ StopWork(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
+ StartWork(ctx);
}
}
return errors::Cancelled(
@@ -482,10 +479,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -581,10 +578,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -599,7 +596,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -620,10 +618,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- std::unique_ptr<IteratorContext> ctx(ctx_ptr);
- auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
+ StopWork(ctx.get());
});
bool make_new_iterator;
{
@@ -649,9 +648,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
// 1a. Get new input tensors or use the exiting ones.
-
bool read_new_input;
-
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
@@ -663,7 +660,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -684,7 +683,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
{
tf_shared_lock l(ckpt_mu_);
worker_thread_states_[thread_index].iterator_creation_status =
- dataset::MakeIteratorFromInputElement(
+ MakeIteratorFromInputElement(
ctx.get(), worker_thread_states_[thread_index].input,
thread_index, dataset()->captured_func_.get(), prefix(),
&worker_thread_states_[thread_index].iterator);
@@ -713,7 +712,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -762,7 +763,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ StopWork(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) return;
@@ -914,7 +917,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_thread_states_[index].iterator.reset();
} else {
std::unique_ptr<IteratorBase> iterator;
- Status s = dataset::MakeIteratorFromInputElement(
+ Status s = MakeIteratorFromInputElement(
ctx, worker_thread_states_[index].input, index,
dataset()->captured_func_.get(), prefix(), &iterator);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
@@ -1058,7 +1061,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
@@ -1067,6 +1069,616 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
-} // namespace
+// The motivation for creating an alternative implementation of parallel
+// interleave is to decouple the degree of parallelism from the cycle length.
+// This makes it possible to change the degree of parallelism (e.g. through
+// auto-tuning) without changing the cycle length (which would change the order
+// in which elements are produced).
+//
+// Furthermore, this class favors modularity over extended functionality. In
+// particular, it refrains from implementing configurable buffering of output
+// elements and prefetching of input iterators, relying on other parts of
+// tf.data to provide this functionality if necessary.
+//
+// The above design choices were made with automated optimizations in mind,
+// isolating the degree of parallelism as the single tunable knob of this
+// implementation.
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
+ public:
+ explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 cycle_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+ OP_REQUIRES(ctx, cycle_length > 0,
+ errors::InvalidArgument("`cycle_length` must be > 0"));
+
+ int64 block_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "block_length", &block_length));
+ OP_REQUIRES(ctx, block_length > 0,
+ errors::InvalidArgument("`block_length` must be > 0"));
+
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+ OP_REQUIRES(
+ ctx, num_parallel_calls <= cycle_length,
+ errors::InvalidArgument(
+ "num_parallel_calls must less than or equal to cycle_length."));
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
+
+ *output = new Dataset(ctx, input, interleave_func_,
+ std::move(captured_func), cycle_length, block_length,
+ num_parallel_calls, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+ int64 block_length, int64 num_parallel_calls,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ interleave_func_(func),
+ captured_func_(std::move(captured_func)),
+ cycle_length_(cycle_length),
+ block_length_(block_length),
+ num_parallel_calls_(num_parallel_calls),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParallelInterleaveDatasetV2Op::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+ Node* cycle_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+ Node* block_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+ Node* num_parallel_calls_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(interleave_func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {{0, input_node},
+ {2, cycle_length_node},
+ {3, block_length_node},
+ {4, num_parallel_calls_node}},
+ {{1, other_arguments}},
+ {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ args_list_(params.dataset->cycle_length_),
+ current_elements_(params.dataset->cycle_length_),
+ element_in_use_(params.dataset->cycle_length_, false),
+ num_parallel_calls_(params.dataset->num_parallel_calls_),
+ thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "parallel_interleave",
+ dataset()->cycle_length_ /* num_threads */,
+ false /* low_latency_hint */)) {}
+
+ ~Iterator() override {
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ auto set_fn = [this](int64 value) {
+ {
+ mutex_lock l(mu_);
+ num_parallel_calls_ = value;
+ }
+ VLOG(2) << "setting parallelism knob to " << value;
+ cond_var_.notify_all();
+ };
+ AddTunableParameter(
+ ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
+ dataset()->cycle_length_ /* max */, std::move(set_fn));
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
+ AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::shared_ptr<InvocationResult> result;
+ do {
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty() &&
+ (!end_of_input_ || num_open_ > 0)) {
+ StopWork(ctx);
+ cond_var_.wait(l);
+ StartWork(ctx);
+ }
+ if (!invocation_results_.empty()) {
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+ cond_var_.notify_all();
+ StopWork(ctx);
+ result->notification.WaitForNotification();
+ StartWork(ctx);
+ } while (result->skip);
+
+ if (result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ }
+ *end_of_sequence = false;
+ return result->status;
+ }
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->skip) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")),
+ ""));
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cycle_index"), cycle_index_));
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("num_open"), num_open_));
+ TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
+ }
+ }
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->skip = reader->Contains(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")));
+ result->notification.Notify();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_open"), &num_open_));
+ TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+ return Status::OK();
+ }
+
+ private:
+ struct InvocationResult {
+ Notification notification; // used for coordination with the consumer
+ Status status; // the invocation status
+ std::vector<Tensor> return_values; // the invocation result values
+ bool skip; // if set the result should be skipped
+ };
+
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ [this, new_ctx]() { RunnerThread(new_ctx); }));
+ }
+ }
+
+ // Fetches up to `results.size()` outputs from the cycle element at
+ // position `cycle_index`.
+ //
+ // If end of input is encountered, the `skip` field of the invocation
+ // result is used to identify results that should be skipped.
+ void FetchOutputs(
+ const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
+ const std::vector<std::shared_ptr<InvocationResult>>& results)
+ LOCKS_EXCLUDED(mu_) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ bool end_of_input = false;
+ for (auto& result : results) {
+ if (!end_of_input) {
+ result->status = current_elements_[cycle_index]->GetNext(
+ ctx.get(), &result->return_values, &end_of_input);
+ }
+ if (end_of_input) {
+ result->skip = true;
+ }
+ result->notification.Notify();
+ if (!result->status.ok()) {
+ break;
+ }
+ }
+
+ // Release the ownership of the cycle element iterator, closing the
+ // iterator if end of input was encountered.
+ {
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
+ }
+ }
+ cond_var_.notify_all();
+ }
+
+ int64 MaxInvocationResults() {
+ return dataset()->cycle_length_ * dataset()->block_length_;
+ }
+
+ // Method responsible for 1) creating iterators out of input elements, 2)
+ // determining the order in which elements are fetched from the iterators,
+ // and 3) scheduling the fetching of the elements to a threadpool.
+ //
+ // This method runs in the `runner_thread` background thread.
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
+ (element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ StopWork(ctx.get());
+ cond_var_.wait(l);
+ StartWork(ctx.get());
+ }
+
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
+
+ while (!element_in_use_[cycle_index_] &&
+ (!end_of_input_ || num_open_ > 0) &&
+ num_calls_ < num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ ++num_open_;
+ }
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
+ }
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
+ }
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ }
+ }
+ cond_var_.notify_all();
+ }
+ }
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].error_message"));
+ }
+
+ Status WriteCurrentElements(IteratorStateWriter* writer)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (current_elements_[idx]) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ args_list_[idx].size()));
+ for (int i = 0; i < args_list_[idx].size(); i++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ args_list_[idx][i]));
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ReadCurrentElements(IteratorContext* ctx,
+ IteratorStateReader* reader)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (reader->Contains(
+ full_name(strings::StrCat("args_size[", idx, "]")))) {
+ int64 args_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ &args_size));
+ args_list_[idx].resize(args_size);
+ for (int i = 0; i < args_size; i++) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ &args_list_[idx][i]));
+ }
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+ prefix(), &current_elements_[idx]));
+ TF_RETURN_IF_ERROR(
+ RestoreInput(ctx, reader, current_elements_[idx]));
+ } else {
+ current_elements_[idx].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads.
+ mutex mu_;
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads. In particular, the runner thread should only
+ // schedule new calls when the number of in-flight calls is less than the
+ // user specified level of parallelism, there are slots available in the
+ // `invocation_results_` buffer, the current cycle element is not in use,
+ // and there are elements left to be fetched.
+ condition_variable cond_var_;
+
+ // Iterator for input elements.
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+
+ // Identifies current cycle element.
+ int64 cycle_index_ = 0;
+
+ // Arguments for creating an iterator for cycle elements.
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+
+ // Iterators for the current cycle elements. Concurrent access is
+ // protected by `element_in_use_`.
+ std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+ // Identifies cycle elements that are in use by worker threads.
+ std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+
+ // Identifies whether end of input has been reached.
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+
+ // Identifies the number of open iterators.
+ int64 num_open_ GUARDED_BY(mu_) = 0;
+
+ // Identifies the maximum number of parallel calls.
+ int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
+
+ // Identifies the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+
+ // Identifies whether background activity should be cancelled.
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ };
+
+ const DatasetBase* const input_;
+ const NameAttrList interleave_func_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ const int64 cycle_length_;
+ const int64 block_length_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+ ParallelInterleaveDatasetV2Op);
+
+} // namespace
+} // namespace data
} // namespace tensorflow
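The comment block introducing ParallelInterleaveDatasetV2Op makes a specific claim: output order is determined by cycle_length and block_length alone, while num_parallel_calls only bounds how many cycle elements are fetched concurrently. The small self-contained simulation below reproduces that consumption order; the sizes are illustrative and not part of this change.

// Sketch: the emission order of ParallelInterleaveDatasetV2 depends only on
// cycle_length and block_length; parallelism never reorders output.
#include <cstdio>
#include <vector>

int main() {
  const int cycle_length = 3;        // illustrative values
  const int block_length = 2;
  const int elements_per_input = 4;
  std::vector<int> next(cycle_length, 0);  // per-cycle-element cursor
  int remaining = cycle_length * elements_per_input;
  int cycle_index = 0;
  while (remaining > 0) {
    for (int i = 0; i < block_length && next[cycle_index] < elements_per_input;
         ++i) {
      std::printf("input %d, element %d\n", cycle_index, next[cycle_index]++);
      --remaining;
    }
    cycle_index = (cycle_index + 1) % cycle_length;
  }
  return 0;
}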
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index bff54813d6..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -33,37 +33,32 @@ namespace {
class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelMapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
- output_shapes_, std::move(captured_func));
+ output_shapes_, use_inter_op_parallelism_,
+ std::move(captured_func));
}
private:
@@ -73,6 +68,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func, int32 num_parallel_calls,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction> captured_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
@@ -80,6 +76,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
num_parallel_calls_(num_parallel_calls),
output_types_(output_types),
output_shapes_(output_shapes),
+ use_inter_op_parallelism_(use_inter_op_parallelism),
captured_func_(std::move(captured_func)) {
input_->Ref();
}
@@ -92,16 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- auto map_func = [this](IteratorContext* ctx,
+ const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+ ParallelMapIteratorFunction map_func =
+ [this, new_prefix](IteratorContext* ctx,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- };
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done), new_prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](
+ IteratorContext* ctx, std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+ result, std::move(done)));
+ };
+ }
- return NewParallelMapIterator(
- {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(init_func), std::move(map_func), num_parallel_calls_);
+ return NewParallelMapIterator({this, new_prefix}, input_,
+ std::move(init_func), std::move(map_func),
+ num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -167,12 +174,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const int32 num_parallel_calls_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
NameAttrList func_;
};
@@ -180,5 +188,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
ParallelMapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
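When use_inter_op_parallelism is false, the ParallelMapDatasetOp change above wraps the original map callback in a second lambda that re-dispatches it through *ctx->runner() instead of calling it inline. The shape of that wrapper, reduced to plain std::function types, is sketched below; all names here are illustrative and not TensorFlow APIs.

#include <cstdio>
#include <functional>

using Runner = std::function<void(std::function<void()>)>;
using MapFn = std::function<void(int)>;

// Returns a callback with the same signature that forwards every invocation
// to the supplied runner, mirroring the capture-by-copy wrapping in the diff.
MapFn WrapWithRunner(MapFn fn, Runner* runner) {
  return [fn, runner](int arg) { (*runner)(std::bind(fn, arg)); };
}

int main() {
  Runner inline_runner = [](std::function<void()> work) { work(); };
  MapFn base = [](int x) { std::printf("mapped %d\n", x); };
  MapFn wrapped = WrapWithRunner(base, &inline_runner);
  wrapped(42);
  return 0;
}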
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 61f8139b9e..5f6052ce83 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -19,7 +19,11 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
namespace tensorflow {
+namespace data {
namespace {
class ParallelMapIterator : public DatasetBaseIterator {
@@ -52,6 +56,25 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ auto set_fn = [this](int64 value) {
+ {
+ mutex_lock l(mu_);
+ num_parallel_calls_ = value;
+ }
+ VLOG(2) << "setting parallelism knob to " << value;
+ cond_var_.notify_all();
+ };
+ // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+ // use it here for the maximum.
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_ /* value */,
+ 1 /* min */, port::NumSchedulableCPUs() /* max */,
+ std::move(set_fn));
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
@@ -67,13 +90,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
}
cond_var_.notify_all();
+ StopWork(ctx);
result->notification.WaitForNotification();
+ StartWork(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
@@ -86,9 +113,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("invocation_results.size"),
- invocation_results_.size()));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
+ invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
@@ -204,8 +230,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
- int64 MaxInvocationResults() { return num_parallel_calls_; }
-
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@@ -225,21 +249,28 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_);
+ {
+ tf_shared_lock l(mu_);
+ new_calls.reserve(num_parallel_calls_);
+ }
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ &&
(num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
+ invocation_results_.size() >= num_parallel_calls_)) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) {
return;
}
while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
+ invocation_results_.size() < num_parallel_calls_) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
@@ -294,7 +325,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
- const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
@@ -303,6 +333,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
@@ -333,4 +365,5 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
std::move(map_func), num_parallel_calls));
}
+} // namespace data
} // namespace tensorflow
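
The `Initialize()` hunk above registers `num_parallel_calls_` as a tunable parameter when the op was built with `kAutoTune`: the setter takes the iterator's lock, updates the knob, and notifies `cond_var_` so that threads blocked on the old value re-check it. A minimal standalone sketch of that set-and-notify pattern, assuming plain std::mutex and std::condition_variable rather than the tf.data types:

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

int main() {
  std::mutex mu;
  std::condition_variable cond_var;
  long long num_parallel_calls = 1;  // autotuned knobs start at 1 in the diff

  // Setter handed to the tuner: update the knob under the lock, then wake
  // anyone waiting on the condition variable.
  std::function<void(long long)> set_fn = [&](long long value) {
    {
      std::lock_guard<std::mutex> l(mu);
      num_parallel_calls = value;
    }
    cond_var.notify_all();
  };

  // A stand-in for the runner thread: block until the knob allows more work.
  std::thread runner([&] {
    std::unique_lock<std::mutex> l(mu);
    cond_var.wait(l, [&] { return num_parallel_calls >= 4; });
    std::cout << "parallelism knob reached " << num_parallel_calls << "\n";
  });

  set_fn(2);  // not enough yet; the runner keeps waiting
  set_fn(4);  // wakes the runner
  runner.join();
  return 0;
}
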
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index 7e6cc586f3..dc26c5cf25 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
+namespace data {
// A function that transforms elements of one dataset into another
// asynchronously. The arguments are:
@@ -47,6 +48,7 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 9057800d94..c28c06da62 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/example_proto_fast_parsing.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"Expected len(dense_defaults) == len(dense_keys) but got: ",
dense_default_tensors.size(), " vs. ", dense_keys_.size()));
- std::vector<Tensor> dense_defaults;
- dense_defaults.reserve(dense_default_tensors.size());
- for (const Tensor& dense_default_t : dense_default_tensors) {
- dense_defaults.push_back(dense_default_t);
- }
+ std::vector<Tensor> dense_defaults(dense_default_tensors.begin(),
+ dense_default_tensors.end());
for (int d = 0; d < dense_keys_.size(); ++d) {
const Tensor& def_value = dense_defaults[d];
@@ -368,5 +365,5 @@ REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
ParseExampleDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index b3272f6bcd..da357339c9 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
namespace tensorflow {
+namespace data {
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
: buffer_limit_(initial_buffer_size) {
@@ -25,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
}
}
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+size_t kBufferLimitThreshold = 2048;
+} // namespace
+
void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
switch (mode_) {
case Mode::kDisabled:
@@ -36,11 +44,16 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
return;
case Mode::kDownswing:
if (current_buffer_size == 0) {
- buffer_limit_ *= 2; // Increase the buffer size.
+ if (buffer_limit_ >= kBufferLimitThreshold) {
+ buffer_limit_ += kBufferLimitThreshold;
+ } else {
+ buffer_limit_ *= 2;
+ }
mode_ = Mode::kUpswing;
}
return;
}
}
+} // namespace data
} // namespace tensorflow
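
The new `kBufferLimitThreshold` changes the autotuner's growth policy: the buffer limit doubles while it is below 2048 and then grows linearly in steps of 2048. A small standalone sketch of that policy (the free function and its default argument are ours, not part of the PrefetchAutotuner API):

#include <cstddef>
#include <iostream>

// Sketch of the growth policy in the hunk above: double the limit while it is
// below the threshold, then grow it linearly by the threshold afterwards.
// The value 2048 matches kBufferLimitThreshold in the diff.
std::size_t NextBufferLimit(std::size_t limit, std::size_t threshold = 2048) {
  if (limit >= threshold) return limit + threshold;
  return limit * 2;
}

int main() {
  std::size_t limit = 1;
  for (int step = 0; step < 16; ++step) {
    std::cout << "step " << step << ": limit = " << limit << "\n";
    limit = NextBufferLimit(limit);
  }
  // The limit grows 1, 2, 4, ..., 2048, then 4096, 6144, 8192, ...
  return 0;
}
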
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h
index fa8a184072..8693205512 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.h
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+namespace data {
// PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator.
//
@@ -66,6 +67,7 @@ class PrefetchAutotuner {
Mode mode_ = Mode::kDisabled;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
index 29a8cc50cd..cfc324fc7e 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace data {
namespace {
TEST(PrefetchAutotuner, Disabled) {
@@ -79,4 +80,5 @@ TEST(PrefetchAutotuner, EnabledSteady) {
}
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 50efbcbe2a..52c421caee 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -12,15 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <deque>
-
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
+#include <deque>
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -70,7 +74,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- auto_tuner_(params.dataset->buffer_size_) {}
+ auto_tuner_(params.dataset->buffer_size_) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
~Iterator() override {
// Signal the prefetch thread to terminate it. We will then
@@ -97,13 +105,16 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
{
mutex_lock l(mu_);
+ auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
+ StopWork(ctx);
cond_var_.wait(l);
+ StartWork(ctx);
}
if (cancelled_) {
@@ -112,7 +123,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
}
if (!buffer_.empty()) {
- return Consume(out_tensors, end_of_sequence);
+ return Consume(out_tensors, end_of_sequence, stats_aggregator);
}
if (prefetch_thread_finished_) {
@@ -200,14 +211,22 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> value;
};
- Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(prefix_end_, "::buffer_utilization"),
+ {static_cast<float>(buffer_.size()) /
+ static_cast<float>(auto_tuner_.buffer_limit())});
+ }
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
Status s = buffer_.front().status;
if (s.ok()) {
*out_tensors = std::move(buffer_.front().value);
}
+ auto_tuner_.RecordConsumption(buffer_.size());
buffer_.pop_front();
*end_of_sequence = false;
@@ -223,10 +242,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
- prefetch_thread_.reset(
- ctx->env()->StartThread({}, "prefetch_thread",
- std::bind(&Iterator::PrefetchThread, this,
- new IteratorContext(*ctx))));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ prefetch_thread_.reset(ctx->env()->StartThread(
+ {}, "prefetch_thread",
+ [this, new_ctx]() { PrefetchThread(new_ctx); }));
}
return Status::OK();
}
@@ -235,8 +254,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// buffer.
//
// It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
+ void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+ StartWork(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -244,7 +264,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+ StopWork(ctx.get());
cond_var_.wait(l);
+ StartWork(ctx.get());
}
if (cancelled_) {
@@ -261,8 +283,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
bool end_of_sequence;
BufferElement buffer_element;
- buffer_element.status =
- input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+ buffer_element.status = input_impl_->GetNext(
+ ctx.get(), &buffer_element.value, &end_of_sequence);
if (buffer_element.status.ok() && end_of_sequence) {
mutex_lock l(mu_);
prefetch_thread_finished_ = true;
@@ -324,6 +346,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
+ string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
@@ -346,6 +369,7 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
*output = new Dataset(ctx, input, buffer_size);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
@@ -354,4 +378,7 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
.HostMemory("input_dataset")
.HostMemory("handle"),
PrefetchDatasetOp);
+} // namespace
+
+} // namespace data
} // namespace tensorflow
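
Two small pieces are added to the prefetch iterator above: `prefix_end_` keeps the last "::"-separated component of the iterator prefix, and `Consume()` records the buffer size divided by the autotuner's limit as a "buffer_utilization" histogram value. A standalone sketch of both calculations, assuming an example prefix string and hand-rolled splitting (the helper name is ours):

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Returns the last "::"-separated component of the prefix, mimicking the
// str_util::Split(..., "::", SkipEmpty()) call in the diff by splitting on
// ':' and skipping empty pieces.
std::string PrefixEnd(const std::string& prefix) {
  std::vector<std::string> components;
  std::string part;
  std::stringstream ss(prefix);
  while (std::getline(ss, part, ':')) {
    if (!part.empty()) components.push_back(part);
  }
  return components.empty() ? prefix : components.back();
}

int main() {
  const std::string prefix = "Iterator::Model::Prefetch";  // example only
  const std::size_t buffer_size = 3;
  const std::size_t buffer_limit = 8;
  const float utilization =
      static_cast<float>(buffer_size) / static_cast<float>(buffer_limit);
  std::cout << PrefixEnd(prefix) << "::buffer_utilization = " << utilization
            << "\n";  // prints "Prefetch::buffer_utilization = 0.375"
  return 0;
}
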
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h
index c40c4b00da..588fb25a06 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.h
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
namespace tensorflow {
+namespace data {
class PrefetchDatasetOp : public UnaryDatasetOpKernel {
public:
@@ -34,6 +35,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
class Dataset;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index 7817170e73..044a791a3f 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random_distributions.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -151,5 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
RandomDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index aa38775125..89fbaae369 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -142,5 +142,5 @@ REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU),
RangeDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 086b552936..c474cb4773 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -691,5 +691,5 @@ REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
TFRecordDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 299949b99f..94e96635ab 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -250,5 +250,5 @@ REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
RepeatDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index fccad933d0..dbe31f37b8 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -45,23 +45,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
OpInputList initial_state_inputs;
OP_REQUIRES_OK(ctx,
ctx->input_list("initial_state", &initial_state_inputs));
- std::vector<Tensor> initial_state;
- initial_state.reserve(initial_state_inputs.size());
- for (const Tensor& t : initial_state_inputs) {
- initial_state.push_back(t);
- }
-
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
+ std::vector<Tensor> initial_state(initial_state_inputs.begin(),
+ initial_state_inputs.end());
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
@@ -279,5 +268,5 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 93a4376836..66466d6a36 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds.
@@ -620,5 +620,5 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
ShuffleAndRepeatDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
new file mode 100644
index 0000000000..5b084a16f0
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -0,0 +1,380 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class SingleThreadedExecutorImpl : public Executor {
+ public:
+ explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
+ : params_(params) {}
+
+ ~SingleThreadedExecutorImpl() override {
+ for (const KernelState& kernel_state : kernels_) {
+ params_.delete_kernel(kernel_state.kernel);
+ }
+ }
+
+ Status Initialize(const Graph& graph) {
+ // Topologically sort `graph` to get a sequence of OpKernels.
+ std::vector<Node*> ordered_nodes;
+ ordered_nodes.reserve(graph.num_nodes());
+ GetReversePostOrder(graph, &ordered_nodes);
+
+ if (ordered_nodes.size() != graph.num_nodes()) {
+ return errors::InvalidArgument("Graph had ", graph.num_nodes(),
+ " but reverse post-order had ",
+ ordered_nodes.size());
+ }
+
+ kernels_.resize(ordered_nodes.size());
+
+ std::unordered_map<Node*, size_t> node_to_index_map;
+
+ // Create the kernel and input-related structures for each node in `graph`.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ node_to_index_map[n] = i;
+
+ for (DataType dt : n->output_types()) {
+ if (IsRefType(dt)) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support reference-typed "
+ "edges.");
+ }
+ }
+
+ if (n->IsControlFlow()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support control flow.");
+ }
+ if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support partitioned graphs.");
+ }
+ if (n->IsCollective()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support collective ops.");
+ }
+
+ KernelState& kernel_state = kernels_[i];
+ TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+ kernel_state.num_inputs = n->num_inputs();
+ kernel_state.num_outputs = n->num_outputs();
+
+ if (i == 0) {
+ kernel_state.input_start_index = 0;
+ } else {
+ const KernelState& previous_kernel_state = kernels_[i - 1];
+ kernel_state.input_start_index =
+ previous_kernel_state.input_start_index +
+ previous_kernel_state.num_inputs;
+ }
+ }
+
+ // Build the mapping from each node output to the input slot for the
+ // corresponding destination node.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ KernelState& kernel_state = kernels_[i];
+ kernel_state.output_locations.resize(kernel_state.num_outputs);
+ for (const Edge* e : n->out_edges()) {
+ if (!e->IsControlEdge()) {
+ kernel_state.output_locations[e->src_output()].push_back(
+ kernels_[node_to_index_map[e->dst()]].input_start_index +
+ e->dst_input());
+ }
+ }
+
+ // Compute allocator attributes for each node output, and corresponding
+ // node input.
+ kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
+ AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();
+
+ OpKernel* op_kernel = kernel_state.kernel;
+ for (int out = 0; out < n->num_outputs(); out++) {
+ DCHECK_LT(out, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
+ if (on_host) {
+ AllocatorAttributes h;
+ h.set_on_host(on_host);
+ attrs[out].Merge(h);
+ }
+ }
+ }
+
+ if (!kernels_.empty()) {
+ const KernelState& last_kernel_state = kernels_.back();
+ total_num_inputs_ =
+ last_kernel_state.input_start_index + last_kernel_state.num_inputs;
+ input_alloc_attrs_.resize(total_num_inputs_);
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
+ for (size_t output_location : kernels_[i].output_locations[j]) {
+ input_alloc_attrs_[output_location] =
+ kernels_[i].output_alloc_attrs[j];
+ }
+ }
+ }
+ } else {
+ total_num_inputs_ = 0;
+ }
+ return Status::OK();
+ }
+
+ // TODO(mrry): Consider specializing the implementation of Executor::Run()
+ // instead, to avoid unnecessary atomic operations in the callback when
+ // running synchronously.
+ void RunAsync(const Args& args, DoneCallback done) override {
+ // The inputs to each kernel are stored contiguously in `inputs`.
+ //
+ // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
+ // determine the range of elements in this vector that correspond to
+ // the inputs of `kernels_[i]`.
+ //
+ // This vector has the following layout:
+ //
+ // * Kernel 0, input 0.
+ // * Kernel 0, input 1.
+ // * ...
+ // * Kernel 0, input `kernels_[0].num_inputs - 1`.
+ // * Kernel 1, input 0.
+ // * ...
+ // * Kernel 1, input `kernels_[1].num_inputs - 1`.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input 0.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
+ //
+ // Note that kernels with zero inputs do not correspond to any elements in
+ // this vector.
+ //
+ // We use `ManualConstructor<Tensor>` to avoid the overhead of
+ // default-constructing an invalid `Tensor` for each slot at the beginning
+ // of execution:
+ // * Elements are initialized when the outputs of a kernel execution are
+ // propagated to the inputs of kernels that depend on them.
+ // * The elements corresponding to the inputs for kernel `i` are destroyed
+ // after kernel `i` executes.
+ // * In an error case (see below), we use the connectivity information in
+ // `KernelState::output_locations` to determine which locations have been
+ // initialized, and manually destroy them.
+ std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_);
+
+ // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
+ // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
+ TensorValueVec node_inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ // Prepare the parameters that will be the same for all kernels.
+ OpKernelContext::Params params;
+ params.step_id = args.step_id;
+ Device* device = params_.device;
+ params.device = device;
+ params.log_memory = false; // TODO(mrry): Too severe?
+ params.record_tensor_accesses = false; // TODO(mrry): Too severe?
+ params.rendezvous = args.rendezvous;
+ params.session_state = args.session_state;
+ params.tensor_store = args.tensor_store;
+ params.cancellation_manager = args.cancellation_manager;
+ // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor
+ // allocations that it performs. Consider specializing its handling in the
+ // executor.
+ params.call_frame = args.call_frame;
+ params.function_library = params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_container = args.step_container;
+ params.slice_reader_cache = nullptr; // TODO(mrry): Too severe?
+ params.inputs = &node_inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Args::Runner runner_copy = args.runner;
+ params.runner = &runner_copy;
+ params.stats_collector = args.stats_collector;
+
+ // NOTE(mrry): We are assuming that the graph is loopless and condless.
+ params.frame_iter = FrameAndIter(0, 0);
+ params.is_input_dead = false;
+
+ // TODO(mrry): Add non-default device context inference.
+ params.op_device_context = nullptr;
+ // TODO(mrry): Consider implementing forwarding.
+ params.forward_from_array = nullptr;
+
+ // Execute the kernels one-at-a-time in topological order.
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ const KernelState& kernel_state = kernels_[i];
+
+ // Prepare the per-kernel parameters.
+ const size_t input_start_index = kernel_state.input_start_index;
+ const size_t num_inputs = kernel_state.num_inputs;
+ const size_t num_outputs = kernel_state.num_outputs;
+
+ node_inputs.clear();
+ node_inputs.resize(num_inputs);
+ input_alloc_attrs.clear();
+ input_alloc_attrs.resize(num_inputs);
+ for (size_t j = 0; j < num_inputs; ++j) {
+ auto t = inputs[input_start_index + j].get();
+ node_inputs[j].tensor = t;
+ input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
+ }
+ params.op_kernel = kernel_state.kernel;
+ input_device_contexts.clear();
+ input_device_contexts.resize(num_inputs);
+ params.output_attr_array = kernel_state.output_alloc_attrs.data();
+ OpKernelContext ctx(&params, num_outputs);
+
+ // Actually execute the kernel.
+ device->Compute(kernel_state.kernel, &ctx);
+
+ if (!ctx.status().ok()) {
+ // On failure, we must manually free all intermediate tensors. We have
+ // already freed all the inputs for kernels up to (but not including)
+ // the `i`th kernel. We scan through the previously executed kernels and
+ // destroy any tensors that were destined to be the input for a kernel
+ // that has not yet executed.
+ for (size_t j = 0; j < i; ++j) {
+ const KernelState& executed_kernel_state = kernels_[j];
+ for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) {
+ for (size_t output_location :
+ executed_kernel_state.output_locations[k]) {
+ if (output_location >= input_start_index) {
+ // Only destroy an output location if it is an input to an
+ // operation that has not yet executed.
+ inputs[output_location].Destroy();
+ }
+ }
+ }
+ }
+ done(ctx.status());
+ return;
+ }
+
+ // Free the inputs to the current kernel.
+ for (size_t j = 0; j < num_inputs; ++j) {
+ inputs[input_start_index + j].Destroy();
+ }
+
+ // Forward the outputs of the kernel to the inputs of subsequent kernels.
+ for (size_t j = 0; j < num_outputs; ++j) {
+ TensorValue val = ctx.release_output(j);
+ // TODO(mrry): Consider flattening the `output_locations` vector
+ // to improve the cache-friendliness of this loop.
+ for (size_t output_location : kernel_state.output_locations[j]) {
+ // TODO(mrry): Validate that the types match the expected values or
+ // ensure that the necessary validation has already happened.
+ inputs[output_location].Init(*val.tensor);
+ }
+ delete val.tensor;
+ }
+ }
+ done(Status::OK());
+ }
+
+ private:
+ const LocalExecutorParams params_;
+
+ // All following members are read-only after Initialize().
+
+ // The sum of the number of inputs for each node in the graph. This determines
+ // the length of the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ size_t total_num_inputs_;
+
+ // Represents cached graph structure state for each kernel.
+ struct KernelState {
+ // The kernel object. Not owned.
+ //
+ // This pointer is managed by `params_.create_kernel()` and
+ // `params_.delete_kernel()`.
+ OpKernel* kernel;
+
+ // These fields determine the range of elements in `inputs` that corresponds
+ // to the inputs of `kernel`.
+ size_t input_start_index;
+ size_t num_inputs;
+
+ size_t num_outputs;
+
+ // For the `j`th output of `kernel`, `output_locations[j]` contains the
+ // locations in the flat `inputs` vector to which that output must be
+ // copied. See comment at the beginning of `RunAsync()` for details.
+ std::vector<std::vector<size_t>>
+ output_locations; // Length = `num_outputs`.
+
+ // Memory space information for each output of `kernel`.
+ std::vector<AllocatorAttributes>
+ output_alloc_attrs; // Length = `num_outputs`.
+ };
+ std::vector<KernelState> kernels_;
+
+ // Memory space information for each input. This information is stored in the
+ // same order as the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ std::vector<AllocatorAttributes>
+ input_alloc_attrs_; // Length = `total_num_inputs_`.
+};
+
+class SingleThreadedExecutorRegistrar {
+ public:
+ SingleThreadedExecutorRegistrar() {
+ ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret;
+ TF_RETURN_IF_ERROR(
+ NewSingleThreadedExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
+static SingleThreadedExecutorRegistrar registrar;
+
+} // namespace
+
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor) {
+ std::unique_ptr<SingleThreadedExecutorImpl> impl(
+ new SingleThreadedExecutorImpl(params));
+ TF_RETURN_IF_ERROR(impl->Initialize(*graph));
+ *executor = impl.release();
+ return Status::OK();
+}
+
+} // namespace data
+} // namespace tensorflow
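
The long comment at the top of `RunAsync()` describes how every kernel's inputs live in one flat vector indexed by `input_start_index`. A standalone sketch of that bookkeeping, assuming a stripped-down stand-in for `KernelState` and an arbitrary topological order of input counts:

#include <cstddef>
#include <iostream>
#include <vector>

// Kernels are laid out in topological order; kernel i's inputs occupy the
// slots [input_start_index, input_start_index + num_inputs) of one flat
// vector, exactly as described in the RunAsync() comment above.
struct KernelLayout {
  std::size_t num_inputs;
  std::size_t input_start_index = 0;
};

int main() {
  // Say the topological order has kernels with 0, 2, 1 and 3 inputs.
  std::vector<KernelLayout> kernels = {{0}, {2}, {1}, {3}};
  for (std::size_t i = 1; i < kernels.size(); ++i) {
    kernels[i].input_start_index =
        kernels[i - 1].input_start_index + kernels[i - 1].num_inputs;
  }
  const std::size_t total_num_inputs =
      kernels.back().input_start_index + kernels.back().num_inputs;

  for (std::size_t i = 0; i < kernels.size(); ++i) {
    std::cout << "kernel " << i << ": slots [" << kernels[i].input_start_index
              << ", " << kernels[i].input_start_index + kernels[i].num_inputs
              << ")\n";
  }
  std::cout << "flat inputs vector length = " << total_num_inputs << "\n";
  return 0;
}
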
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h
new file mode 100644
index 0000000000..e934352a1d
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/executor.h"
+
+namespace tensorflow {
+namespace data {
+
+// Creates a new `Executor` for executing `graph` synchronously on the caller
+// thread.
+//
+// NOTE(mrry): The returned executor is optimized to impose low overhead on
+// graphs that perform a small amount of work (e.g. <15us of work per graph on
+// present architectures). It eschews concurrency, because issuing work to
+// multiple threads can dominate the cost of executing small ops synchronously,
+// and because contention in the executor data structures can reduce throughput
+// (in terms of ops executed per unit time).
+//
+// However, the current implementation has the following limitations:
+//
+// 1. Reference-typed tensors are not supported and will not be supported in
+// the future.
+// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not
+// currently supported. The current plan is to extend support to "functional"
+// control flow after the TensorFlow APIs transition to building graphs in
+// that form (e.g. `tf.cond_v2()`).
+// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported.
+// The present implementation executes kernels one at a time in topological
+// order, and cannot currently distinguish between disconnected subgraphs
+// that are logically connected by subgraphs on a different device.
+// 4. Memory logging is not currently supported.
+// 5. Allocation forwarding is not currently supported.
+// 6. Non-default device contexts are not currently supported. In effect, this
+// limits the executor to CPU devices.
+// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null
+// are not currently supported.
+//
+// The single-threaded executor is primarily suitable for executing simple
+// TensorFlow functions, such as one might find in a `tf.data` pipeline.
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
new file mode 100644
index 0000000000..6244e287bb
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
@@ -0,0 +1,332 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+ ExecutorTest()
+ : device_(DeviceFactory::NewDevice("CPU", {},
+ "/job:localhost/replica:0/task:0")) {}
+
+ ~ExecutorTest() override {
+ // There should always be exactly one Ref left on the Rendezvous
+ // when the test completes.
+ CHECK(rendez_->Unref());
+ delete exec_;
+ delete device_;
+ }
+
+ // Resets exec_ with a new executor built from the given graph.
+ void Create(std::unique_ptr<const Graph> graph) {
+ const int version = graph->versions().producer();
+ LocalExecutorParams params;
+ params.device = device_;
+ params.create_kernel = [this, version](const NodeDef& ndef,
+ OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ };
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ delete exec_;
+ TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+ runner_ = [](std::function<void()> fn) { fn(); };
+ rendez_ = NewLocalRendezvous();
+ }
+
+ Status Run(Rendezvous* rendez) {
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Status Run(CallFrameInterface* call_frame) {
+ Executor::Args args;
+ args.call_frame = call_frame;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Device* device_ = nullptr;
+ Executor* exec_ = nullptr;
+ Executor::Args::Runner runner_;
+ Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor<float>
+Tensor V(const float val) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = val;
+ return tensor;
+}
+
+// An int32 val -> Tensor<int32>
+Tensor VI(const int32 val) {
+ Tensor tensor(DT_INT32, TensorShape({}));
+ tensor.scalar<int32>()() = val;
+ return tensor;
+}
+
+// A bool val -> Tensor<bool>
+Tensor VB(const bool val) {
+ Tensor tensor(DT_BOOL, TensorShape({}));
+ tensor.scalar<bool>()() = val;
+ return tensor;
+}
+
+// A double val -> Tensor<double>
+Tensor VD(const double val) {
+ Tensor tensor(DT_DOUBLE, TensorShape({}));
+ tensor.scalar<double>()() = val;
+ return tensor;
+}
+
+// Tensor<float> -> a float val.
+float V(const Tensor& tensor) {
+ CHECK_EQ(tensor.dtype(), DT_FLOAT);
+ CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+ return tensor.scalar<float>()();
+}
+
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+ const string& receiver, const string& name) {
+ Rendezvous::ParsedKey result;
+ TF_CHECK_OK(
+ Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+ name, FrameAndIter(0, 0)),
+ &result));
+ return result;
+}
+
+TEST_F(ExecutorTest, SimpleAdd) {
+ // c = a + b
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ auto in1 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ auto tmp = test::graph::Add(g.get(), in0, in1);
+ test::graph::Retval(g.get(), 0, tmp);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(2.0, V(retvals[0])); // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+ // v0 <- a
+ // v1 = v0 + v0
+ // v2 = v1 + v1
+ // ... ...
+ // v10 = v9 + v9
+ //
+ // b <- v10
+ // All nodes are executed by one thread.
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto v = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ const int N = 10;
+ for (int i = 1; i <= N; ++i) {
+ v = test::graph::Add(g.get(), v, v);
+ }
+ // out <- v10
+ test::graph::Retval(g.get(), 0, v);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ // a = 1.0
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+// a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+// a + ((a + a) + a)
+// (a + a) + (a + a)
+// ((a + a) + a) + a
+// are all possibly generated.
+void BuildTree(int N, Graph* g) {
+ CHECK_GT(N, 1);
+ // A single input node "in".
+ auto in = test::graph::Arg(g, 0, DT_FLOAT);
+ std::vector<Node*> nodes;
+ int i = 0;
+ // Duplicate "in" N times. Each copy is named l0, l1, l2, ....
+ for (; i < N; ++i) {
+ nodes.push_back(test::graph::Identity(g, in, 0));
+ }
+ random::PhiloxRandom philox(0, 17);
+ random::SimplePhilox rnd(&philox);
+ while (nodes.size() > 1) {
+ // Randomly pick two from nodes and add them. The resulting node
+ // is named like n10, n11, ..., and is put back into "nodes".
+ int x = rnd.Uniform(nodes.size());
+ auto in0 = nodes[x];
+ nodes[x] = nodes.back();
+ nodes.resize(nodes.size() - 1);
+ x = rnd.Uniform(nodes.size());
+ auto in1 = nodes[x];
+ // node = in0 + in1.
+ nodes[x] = test::graph::Add(g, in0, in1);
+ }
+ // The final output node "out".
+ test::graph::Retval(g, 0, nodes.back());
+ FixupSourceAndSinkEdges(g);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ BuildTree(4096, g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(4096.0, V(retvals[0]));
+}
+
+TEST_F(ExecutorTest, OpError) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto zero = test::graph::Constant(g.get(), V(0.0));
+ auto inf = test::graph::Unary(g.get(), "Reciprocal", zero);
+ auto check = test::graph::CheckNumerics(g.get(), inf, "message");
+ auto two = test::graph::Constant(g.get(), V(2.0));
+ test::graph::Binary(g.get(), "Mul", check, two);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({}, {});
+ // Fails with InvalidArgument when CheckNumerics sees the Inf value.
+ EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame)));
+}
+
+static void BM_executor(int iters, int width, int depth) {
+#ifdef PLATFORM_GOOGLE
+ BenchmarkUseRealTime();
+#endif // PLATFORM_GOOGLE
+ Graph* g = new Graph(OpRegistry::Global());
+ random::PhiloxRandom philox(1729, 17);
+ random::SimplePhilox rand(&philox);
+ uint64 cur = 0;
+ uint32 r = 1 + rand.Rand32() % width;
+ std::vector<Node*> ready_nodes;
+ for (int i = 0; i < r; ++i) {
+ ready_nodes.push_back(test::graph::NoOp(g, {}));
+ ++cur;
+ }
+ for (int i = 0; i < depth; ++i) {
+ std::random_shuffle(ready_nodes.begin(), ready_nodes.end());
+ r = 1 + rand.Rand32() % (ready_nodes.size());
+ std::vector<Node*> control_inputs;
+ for (int j = 0; j < r; ++j) {
+ control_inputs.push_back(ready_nodes.back());
+ ready_nodes.pop_back();
+ }
+ Node* n = test::graph::NoOp(g, control_inputs);
+ ++cur;
+ r = 1 + rand.Rand32() % width;
+ for (int j = 0; j < r; ++j) {
+ ready_nodes.push_back(test::graph::NoOp(g, {n}));
+ ++cur;
+ }
+ }
+ FixupSourceAndSinkEdges(g);
+#ifdef PLATFORM_GOOGLE
+ SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
+ SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
+#endif // PLATFORM_GOOGLE
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .Run(iters);
+}
+
+// Tall skinny graphs
+BENCHMARK(BM_executor)->ArgPair(16, 1024);
+BENCHMARK(BM_executor)->ArgPair(32, 8192);
+
+// Short fat graphs
+BENCHMARK(BM_executor)->ArgPair(1024, 16);
+BENCHMARK(BM_executor)->ArgPair(8192, 32);
+
+// Tall fat graph
+BENCHMARK(BM_executor)->ArgPair(1024, 1024);
+
+// TODO(mrry): This benchmark currently crashes with a use-after free, because
+// test::Benchmark::RunWithArgs() assumes that the executor will take ownership
+// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the
+// duration of the benchmark. Since the single threaded executor does not retain
+// a copy of the graph, this fails.
+//
+// TODO(mrry): Add support for Arg/Retval "function call convention" in
+// `test::Benchmark::RunWithArgs()`.
+#if 0
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+static void BM_FeedInputFetchOutput(int iters) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // z = x + y: x and y are provided as benchmark inputs. z is the
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+ Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
+ Node* sum = test::graph::Add(g, x, y);
+ Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
+ FixupSourceAndSinkEdges(g);
+ Tensor val(DT_FLOAT, TensorShape({}));
+ val.scalar<float>()() = 3.14;
+ SetBenchmarkItemsProcessed(static_cast<int64>(iters));
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .RunWithArgs({{x, val}, {y, val}}, {z}, iters);
+}
+BENCHMARK(BM_FeedInputFetchOutput);
+#endif
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index fe7ef38d5f..b8c7fb15f4 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -187,5 +187,5 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index 14df3a6801..1e73cfc753 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -293,5 +293,5 @@ REGISTER_KERNEL_BUILDER(Name("SlideDataset").Device(DEVICE_CPU),
SlideDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index e526578701..85b1e50695 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -274,5 +274,5 @@ TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL);
#undef REGISTER_DATASET_KERNEL
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc
index ffabda1a8a..783d1e6cb2 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.cc
+++ b/tensorflow/core/kernels/data/sql/driver_manager.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
@@ -30,5 +30,5 @@ std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
}
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h
index a34691b5a2..c5428f396b 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.h
+++ b/tensorflow/core/kernels/data/sql/driver_manager.h
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/sql/query_connection.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
// A factory class for creating `QueryConnection` instances.
@@ -35,7 +35,7 @@ class DriverManager {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h
index e9ffca202f..2fd229a9bf 100644
--- a/tensorflow/core/kernels/data/sql/query_connection.h
+++ b/tensorflow/core/kernels/data/sql/query_connection.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
+namespace data {
class IteratorContext;
@@ -63,7 +64,7 @@ class QueryConnection {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
index 7cd07bd8ec..5108e83976 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
SqliteQueryConnection::SqliteQueryConnection() {}
@@ -115,5 +115,5 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
}
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
index 81b19530b7..175492c49d 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
class SqliteQueryConnection : public QueryConnection {
@@ -50,7 +50,7 @@ class SqliteQueryConnection : public QueryConnection {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 2aa153fcfa..6bbe459332 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -24,8 +24,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
-
+namespace data {
namespace {
+
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.
@@ -211,5 +212,5 @@ class SqlDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 75af73df54..f5314f7a75 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
@@ -135,4 +136,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index b133cfab54..a7ded67876 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+namespace data {
namespace {
static mutex* get_counters_map_lock() {
@@ -145,4 +146,5 @@ REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
StatsAggregatorSummaryOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 8957f5d997..e9e42f05a1 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// This op defines a `Dataset` that passes through its input elements and
@@ -248,4 +249,5 @@ REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
BytesProducedStatsDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index e5c237dfaa..e5cdfdd732 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -174,5 +174,5 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index fc21c3235a..ca4ea25b89 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,17 +29,11 @@ class TensorDatasetOp : public DatasetOpKernel {
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
// the "shapes" attr.;
- std::vector<Tensor> components;
- components.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- components.push_back(t);
- }
+ std::vector<Tensor> components(inputs.begin(), inputs.end());
*output = new Dataset(ctx, std::move(components));
}
@@ -74,7 +69,13 @@ class TensorDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
@@ -135,5 +136,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU),
TensorDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
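
Note: the hunk above changes how TensorDataset serializes its components: when the serialization context carries an input list, each tensor is emitted as a placeholder and recorded as a feed instead of being embedded in the graph as a constant. The following standalone C++ sketch illustrates only that branch; FakeGraph, SerializeComponent and the int payloads are hypothetical stand-ins, not TensorFlow APIs.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a graph builder that can emit either a constant
// node (value baked into the graph) or a placeholder node (value fed later).
struct FakeGraph {
  std::vector<std::string> nodes;
  std::string AddConstant(int value) {
    nodes.push_back("Const_" + std::to_string(value));
    return nodes.back();
  }
  std::string AddPlaceholder() {
    nodes.push_back("Placeholder_" + std::to_string(nodes.size()));
    return nodes.back();
  }
};

// Mirrors the branch on ctx->input_list() above: a null input list means
// "serialize the value inline"; a non-null one means "emit a placeholder and
// remember the (name, value) pair so it can be fed at run time".
std::string SerializeComponent(
    FakeGraph* graph, int value,
    std::vector<std::pair<std::string, int>>* input_list) {
  if (input_list != nullptr) {
    std::string name = graph->AddPlaceholder();
    input_list->emplace_back(name, value);
    return name;
  }
  return graph->AddConstant(value);
}

int main() {
  FakeGraph graph;
  std::vector<std::pair<std::string, int>> feeds;
  // With a feed list: the component becomes a placeholder plus a feed entry.
  SerializeComponent(&graph, 42, &feeds);
  // Without one: the component is baked into the graph as a constant.
  SerializeComponent(&graph, 7, nullptr);
  for (const auto& n : graph.nodes) std::cout << n << "\n";
  for (const auto& f : feeds) std::cout << f.first << " <- " << f.second << "\n";
}

The same pattern appears again in the tensor_slice_dataset_op.cc hunk further down.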
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index ccd5e60acc..2ed636a400 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a,
@@ -648,5 +648,5 @@ REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU),
EnqueueInQueueDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 5b051e0e08..7dc64b0a75 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -14,11 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -30,8 +31,6 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
: DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
std::vector<Tensor> components;
@@ -93,7 +92,13 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
@@ -163,5 +168,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
TensorSliceDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 1a79f72b28..81c432b938 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -204,5 +204,5 @@ REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
UnbatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 0ab6beabfc..2ad4711aab 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace data {
namespace {
class WindowDataset : public DatasetBase {
@@ -107,4 +108,5 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
return Status::OK();
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h
index 7bd31a0bc7..84cb3c7860 100644
--- a/tensorflow/core/kernels/data/window_dataset.h
+++ b/tensorflow/core/kernels/data/window_dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
+namespace data {
// Creates a dataset representing an eagerly-collected window of elements.
//
@@ -43,6 +44,7 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
std::vector<PartialTensorShape> output_shapes,
DatasetBase** out_dataset);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 41bf9d43fe..ac44623ce2 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/window_dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- *output = new Dataset(ctx, window_size, input);
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+ OP_REQUIRES(
+ ctx, window_stride > 0,
+ errors::InvalidArgument("Window stride must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+ *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+ drop_remainder);
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+ int64 window_shift, int64 window_stride, bool drop_remainder)
: DatasetBase(DatasetContext(ctx)),
+ input_(input),
window_size_(window_size),
- input_(input) {
+ window_shift_(window_shift),
+ window_stride_(window_stride),
+ drop_remainder_(drop_remainder) {
input_->Ref();
}
@@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+ return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_,
+ window_stride_, drop_remainder_, ")::Dataset");
}
protected:
@@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* window_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ Node* window_size_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+ Node* window_shift_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+ Node* window_stride_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+ Node* drop_remainder_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size}, output));
+ b->AddDataset(this,
+ {input_graph_node, window_size_node, window_shift_node,
+ window_stride_node, drop_remainder_node},
+ output));
return Status::OK();
}
@@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- // Each row of `window_elements` is a tuple of tensors from the
- // input iterator.
+ const int64 window_size = dataset()->window_size_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> window_elements;
+ Status status = Status::OK();
{
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impl_ && buffer_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
- window_elements.reserve(dataset()->window_size_);
- *end_of_sequence = false;
- for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> window_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
- end_of_sequence));
- if (!*end_of_sequence) {
- window_elements.emplace_back(std::move(window_element_tuple));
- } else {
- input_impl_.reset();
+
+ // Add elements to the buffer.
+ size_t target_size = TargetBufferSize(window_size, window_stride);
+ if (input_impl_) {
+ *end_of_sequence = false;
+ for (size_t i = buffer_.size();
+ i < target_size && !*end_of_sequence; ++i) {
+ std::vector<Tensor> element;
+ Status status =
+ input_impl_->GetNext(ctx, &element, end_of_sequence);
+ if (!*end_of_sequence) {
+ buffer_.emplace_back(std::move(element), status);
+ } else {
+ input_impl_.reset();
+ }
+ }
+ }
+
+ // If there are not enough elements and `drop_remainder` is set, we do
+ // not wish to return a smaller window.
+ if (buffer_.empty() ||
+ (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+ window_elements.reserve(num_elements);
+ for (size_t i = 0; i < num_elements; ++i) {
+ status.Update(buffer_[window_stride * i].status);
+ if (!status.ok()) {
+ break;
+ }
+ window_elements.emplace_back(buffer_[window_stride * i].result);
+ }
+
+ // Shift the window, discarding elements if necessary.
+ int buffer_size = buffer_.size();
+ if (window_shift >= buffer_size) {
+ for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ // Ignore non-error status of discarded elements.
+ input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+ if (end_of_input) {
+ input_impl_.reset();
+ }
}
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
}
}
- if (window_elements.empty()) {
- DCHECK(*end_of_sequence);
- return Status::OK();
+ if (!status.ok()) {
+ return status;
}
+ // Construct output tensors.
const size_t num_tuple_components = window_elements[0].size();
const int64 num_window_elements = window_elements.size();
+ *end_of_sequence = false;
for (size_t idx = 0; idx < num_tuple_components; ++idx) {
DatasetBase* window_dataset;
std::vector<std::vector<Tensor>> window_component_elements;
@@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
&out_tensors->back()));
}
- *end_of_sequence = false;
return Status::OK();
}
@@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
}
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+ buffer_[i].result.size()));
+ for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
@@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
+ // Restore buffer.
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
+ int64 vector_size;
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat("buffer[", i, "].size"), &vector_size));
+ buffer_[i].result.resize(vector_size);
+ for (int64 j = 0; j < vector_size; j++) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ &buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
private:
+ struct InvocationResult {
+ InvocationResult() = default;
+ InvocationResult(std::vector<Tensor>&& result, const Status& status)
+ : result(result), status(status) {}
+
+ std::vector<Tensor> result;
+ Status status;
+ };
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].error_message"));
+ }
+
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
+ std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
- const int64 window_size_;
const DatasetBase* const input_;
+ const int64 window_size_;
+ const int64 window_shift_;
+ const int64 window_stride_;
+ const bool drop_remainder_;
};
};
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
-
} // namespace
-
+} // namespace data
} // namespace tensorflow
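
Note: for reference, a standalone sketch of the size/shift/stride/drop_remainder semantics that the rewritten GetNextInternal implements, using a plain std::vector in place of the input iterator. TargetBufferSize mirrors the helper of the same name in the hunk; Windows is an illustrative approximation of the behavior, not the kernel's code.

#include <cstdint>
#include <iostream>
#include <vector>

// Same formula as TargetBufferSize() above: to build one window of `size`
// elements sampled every `stride` input elements, the iterator needs
// (size - 1) * stride + 1 elements in its buffer.
int64_t TargetBufferSize(int64_t size, int64_t stride) {
  return (size - 1) * stride + 1;
}

// Illustrative only: emulate window extraction over a vector of ints.
std::vector<std::vector<int>> Windows(const std::vector<int>& input,
                                      int64_t size, int64_t shift,
                                      int64_t stride, bool drop_remainder) {
  std::vector<std::vector<int>> out;
  for (size_t start = 0; start < input.size(); start += shift) {
    std::vector<int> window;
    for (int64_t i = 0; i < size; ++i) {
      size_t idx = start + i * stride;
      if (idx >= input.size()) break;
      window.push_back(input[idx]);
    }
    if (window.empty()) break;
    // Trailing windows that cannot be filled are dropped when requested.
    if (drop_remainder && static_cast<int64_t>(window.size()) < size) break;
    out.push_back(window);
  }
  return out;
}

int main() {
  std::vector<int> data{0, 1, 2, 3, 4, 5, 6, 7};
  // size=3, shift=2, stride=2 needs a buffer of (3 - 1) * 2 + 1 = 5 elements.
  std::cout << "buffer: " << TargetBufferSize(3, 2) << "\n";
  for (const auto& w : Windows(data, 3, 2, 2, /*drop_remainder=*/true)) {
    for (int v : w) std::cout << v << ' ';   // prints "0 2 4" then "2 4 6"
    std::cout << "\n";
  }
}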
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 1c49874a6a..3f76695bb1 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
-
+namespace data {
namespace {
class ToTFRecordOp : public AsyncOpKernel {
@@ -104,4 +104,5 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
ToTFRecordOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index e4306579ed..61a2078f46 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -175,5 +175,5 @@ class ZipDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 33ed5522d0..d705e82b0d 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -255,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp {
TensorShape shape({1});
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
output_tensor->vec<int64>()(0) = nan_count;
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
};
@@ -380,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp {
bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
positive_inf_count == 0;
if (!mute) {
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
}
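
Note: both debug_ops.h changes stop discarding the Status returned by PublishTensor. A minimal sketch of the pattern, with a toy Status type and a RETURN_IF_ERROR macro standing in for tensorflow::Status and OP_REQUIRES_OK (both stand-ins are hypothetical, shown only to illustrate checking a status-returning call instead of dropping it).

#include <iostream>
#include <string>

// Toy status type; only ok()/message are needed for the illustration.
struct Status {
  bool ok_;
  std::string message;
  static Status OK() { return {true, ""}; }
  static Status Error(const std::string& m) { return {false, m}; }
  bool ok() const { return ok_; }
};

Status PublishTensor(bool healthy) {
  return healthy ? Status::OK() : Status::Error("debug publish failed");
}

// Analogous to OP_REQUIRES_OK: bail out of the caller on error.
#define RETURN_IF_ERROR(expr)   \
  do {                          \
    Status _s = (expr);         \
    if (!_s.ok()) return _s;    \
  } while (0)

Status Compute(bool healthy) {
  // Before the patch the result of PublishTensor() was discarded; now a
  // failed publish surfaces as an op error.
  RETURN_IF_ERROR(PublishTensor(healthy));
  return Status::OK();
}

int main() {
  std::cout << Compute(true).ok() << " " << Compute(false).message << "\n";
}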
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index b4dcf0a74b..ae451be7e2 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels_));
- OP_REQUIRES(context, width > 0 && header_size >= 0,
+ OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
+ OP_REQUIRES(context, height != 0,
+ errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
@@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel {
const int32 abs_height = abs(height);
// there may be padding bytes when the width is not a multiple of 4 bytes
- // 8 * channels == bits per pixel
- const int row_size = (8 * channels_ * width + 31) / 32 * 4;
+ const int row_size = (channels_ * width + 3) / 4 * 4;
const int64 last_pixel_offset = static_cast<int64>(header_size) +
(abs_height - 1) * row_size +
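
Note: the new row_size expression rounds the per-row byte count up to a 4-byte boundary directly; for whole-byte channel counts it is arithmetically equivalent to the old bit-based formula ((8*c*w + 31) / 32 * 4 == (c*w + 3) / 4 * 4). A small worked example:

#include <cstdio>

// BMP rows are padded to a 4-byte boundary; this is the patched expression.
int RowSize(int channels, int width) { return (channels * width + 3) / 4 * 4; }

int main() {
  // 3-channel (BGR) image, width 5: 15 payload bytes -> padded to 16.
  std::printf("%d\n", RowSize(3, 5));  // 16
  // 1-channel image, width 7: 7 bytes -> padded to 8.
  std::printf("%d\n", RowSize(1, 7));  // 8
  // Already aligned: 4 channels, width 5 -> 20 bytes, no extra padding.
  std::printf("%d\n", RowSize(4, 5));  // 20
}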
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 3eed847c16..6bfb5bd5bc 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
for (int i = 0; i < record_defaults.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index b01db91720..fb2a4cc8ef 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
data.shaped<T, 2>({indices_vec.dimension(0), slice_size});
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- T* merged_base = &merged_flat(0, 0);
- const T* data_base = &data_flat(0, 0);
+ T* merged_base = merged_flat.data();
+ const T* data_base = data_flat.data();
for (int i = 0; i < indices_vec.size(); i++) {
int32 index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
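
Note: switching from &merged_flat(0, 0) to merged_flat.data() avoids indexing element 0 of a tensor map that may have no elements; .data() is well defined in that case. The analogous situation with std::vector, purely as an illustration (not the kernel's code):

#include <iostream>
#include <vector>

int main() {
  std::vector<int> empty;
  // v.data() is well defined for an empty container (it may be null),
  // whereas &v[0] indexes a nonexistent element, which is undefined behavior.
  int* p = empty.data();
  std::cout << (p == nullptr ? "null" : "non-null") << "\n";
  // int* bad = &empty[0];  // never do this on an empty vector
}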
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index e13e548f86..8edf7d4a2c 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -51,48 +51,55 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
- const TensorReshapingOp<
+ IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
- const TensorReverseOp<const array<bool, 5>, const Kernel> >,
+ 2>,
+ const TensorShufflingOp<
+ const array<
+ typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<const Eigen::array<bool, 5>,
+ const Kernel>>>>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
+ 2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > >,
+ const OutputBackward>>>>,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+ IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
+ 2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
- const TensorReshapingOp<
+ const OutputBackward>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
- const TensorReverseOp<const array<bool, 5>,
- const Kernel> > > > >::type
+ 2>,
+ const TensorShufflingOp<
+ const array<
+ typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<const Eigen::array<bool, 5>,
+ const Kernel>>>>>>>::type
CuboidConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputPlanes,
typename internal::traits<OutputBackward>::Index inputRows,
typename internal::traits<OutputBackward>::Index inputCols,
- const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
- const DenseIndex strideCols = 1) {
+ const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1,
+ const DenseIndex col_stride = 1) {
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
internal::traits<Kernel>::NumDimensions,
- internal::traits<Kernel>::Layout, TensorIndex> >
+ internal::traits<Kernel>::Layout, TensorIndex>>
kern(kernel);
const TensorRef<
const Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -125,58 +132,45 @@ CuboidConvolutionBackwardInput(
const TensorIndex outputCols =
isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4];
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
- } else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
- }
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ // Computing the forward padding.
+ const TensorIndex forward_pad_top_z = numext::maxi<Index>(
+ 0,
+ ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2);
+ const TensorIndex forward_pad_top = numext::maxi<Index>(
+ 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
+ const TensorIndex forward_pad_left = numext::maxi<Index>(
+ 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
+
+ const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z;
+ const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
+ const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
+
+ const TensorIndex padding_bottom_z = inputPlanes -
+ (outputPlanes - 1) * plane_stride - 2 -
+ padding_top_z + kernelPlanesEff;
+ const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
+ 2 - padding_top + kernelRowsEff;
+ const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
+ 2 - padding_left + kernelColsEff;
+
+ eigen_assert(padding_top_z >= 0);
eigen_assert(padding_top >= 0);
eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom_z >= 0);
eigen_assert(padding_bottom >= 0);
eigen_assert(padding_right >= 0);
- // The kernel has dimensions filters X channels X patch_planes X patch_rows X
- // patch_cols.
+ // The kernel has dimensions :
+ // filters x channels x patch_planes x patch_rows x patch_cols.
// We need to reverse the kernel along the spatial dimensions.
- array<bool, 5> kernel_reverse;
+ Eigen::array<bool, 5> kernel_reverse;
if (isColMajor) {
kernel_reverse[0] = false;
kernel_reverse[1] = false;
@@ -191,15 +185,35 @@ CuboidConvolutionBackwardInput(
kernel_reverse[4] = false;
}
- DSizes<TensorIndex, 3> kernel_dims;
+ // Reorder the dimensions to:
+ // filters x patch_planes x patch_rows x patch_cols x channels
+ array<TensorIndex, 5> kernel_shuffle;
if (isColMajor) {
- kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelRows * kernelCols * kernelPlanes;
+ // From: filters x channels x planes x rows x cols
+ // To: filters x planes x rows x cols x channels
+ kernel_shuffle[0] = 0;
+ kernel_shuffle[1] = 2;
+ kernel_shuffle[2] = 3;
+ kernel_shuffle[3] = 4;
+ kernel_shuffle[4] = 1;
} else {
- kernel_dims[0] = kernelRows * kernelCols * kernelPlanes;
+ // From: cols x rows x planes x channels x filters
+ // To: channels x cols x rows x planes x filters
+ kernel_shuffle[0] = 3;
+ kernel_shuffle[1] = 0;
+ kernel_shuffle[2] = 1;
+ kernel_shuffle[3] = 2;
+ kernel_shuffle[4] = 4;
+ }
+
+ // Collapse the dims
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelFilters;
+ } else {
+ kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels;
}
// The output_backward has dimensions out_depth X out_planes X out_rows X
@@ -208,36 +222,32 @@ CuboidConvolutionBackwardInput(
// dimensions:
// out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes *
// input_rows * input_cols * OTHERS)
- DSizes<TensorIndex, 3> pre_contract_dims;
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[0] =
+ kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = inputPlanes * inputRows * inputCols;
for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[2] *= out.dimension(i);
+ pre_contract_dims[1] *= out.dimension(i);
}
} else {
- pre_contract_dims[2] = kernelFilters;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[1] =
+ kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = inputPlanes * inputRows * inputCols;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= out.dimension(i);
}
}
- // We will contract along dimensions (0, 2) in kernel and (0, 1) in
- // output_backward, if this is col-major, and
- // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this
- // row-major.
- array<IndexPair<TensorIndex>, 2> contract_dims;
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, kernelPlanes, kernelRows and kernelCols.
+ array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
} else {
// row-major: output.patches.contract(kernel)
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
}
// Post contraction, the dimensions of the input_backprop is
@@ -261,40 +271,31 @@ CuboidConvolutionBackwardInput(
}
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
-
return choose(
Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
kernel.reverse(kernel_reverse)
+ .shuffle(kernel_shuffle)
.reshape(kernel_dims)
+ .eval()
.contract(output_backward
.extract_volume_patches(
kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols, padding_ztop,
- padding_zbottom, padding_top, padding_bottom,
+ plane_stride, row_stride, col_stride, padding_top_z,
+ padding_bottom_z, padding_top, padding_bottom,
padding_left, padding_right)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
output_backward
.extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
+ plane_stride, row_stride, col_stride,
+ padding_top_z, padding_bottom_z, padding_top,
padding_bottom, padding_left, padding_right)
.reshape(pre_contract_dims)
- .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims),
+ .contract(kernel.reverse(kernel_reverse)
+ .shuffle(kernel_shuffle)
+ .reshape(kernel_dims)
+ .eval(),
contract_dims)
.reshape(post_contract_dims));
}
@@ -322,48 +323,69 @@ CuboidConvolutionBackwardInput(
*/
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
- internal::traits<OutputBackward>::Layout == ColMajor,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
+ internal::traits<Input>::Layout == ColMajor,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input>,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>,
const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
const TensorVolumePatchOp<
Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > > > >,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>>>>>,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input> > > > > >::type
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>>>>>>::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -374,11 +396,11 @@ CuboidConvolutionBackwardKernel(
typedef typename internal::traits<Input>::Index TensorIndex;
TensorRef<Tensor<typename internal::traits<Input>::Scalar,
internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex> >
+ internal::traits<Input>::Layout, TensorIndex>>
in(input);
TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -392,6 +414,13 @@ CuboidConvolutionBackwardKernel(
internal::traits<OutputBackward>::NumDimensions,
YOU_MADE_A_PROGRAMMING_MISTAKE);
+ // We do not support higher dimensional backward convolutions, or convolutions
+ // without batch dimension.
+ // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+ // dimension in eigen_backward_cuboid_convolutions_test.cc.
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+
const TensorIndex inputPlanes =
isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
const TensorIndex inputRows =
@@ -406,213 +435,174 @@ CuboidConvolutionBackwardKernel(
const TensorIndex outputCols =
isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
+ // Number of filters. This is the same as the output depth.
const TensorIndex kernelFilters =
isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
+ // Number of channels. This is the same as the input depth.
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
+ // Number of batches in the input tensor.
+ const TensorIndex batch =
+ isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
+
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ // Compute forward padding from input and output_backward dimensions.
+ const TensorIndex padPlanes = numext::maxi<Index>(
+ 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
+ const TensorIndex padRows = numext::maxi<Index>(
+ 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows);
+ const TensorIndex padCols = numext::maxi<Index>(
+ 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
+
+ const TensorIndex padding_top_z = padPlanes / 2;
+ const TensorIndex padding_top = padRows / 2;
+ const TensorIndex padding_left = padCols / 2;
+
+ // Compute paddings for output_backward before extracting patches.
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+ const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+ const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+ const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+ const auto padded_out_rows = inputRows + kernelRows - 1;
+ const auto padded_out_cols = inputCols + kernelCols - 1;
+ const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+ const auto top_pad_rows = kernelRows - 1 - padding_top;
+ const auto left_pad_cols = kernelCols - 1 - padding_left;
+ const auto bottom_pad_planes =
+ padded_out_planes - expanded_out_planes - top_pad_planes;
+ const auto bottom_pad_rows =
+ padded_out_rows - expanded_out_rows - top_pad_rows;
+ const auto right_pad_cols =
+ padded_out_cols - expanded_out_cols - left_pad_cols;
+
+ // Reorder output_backward dimensions.
+ array<TensorIndex, 5> output_backward_shuffle;
+ if (isColMajor) {
+ // From: [out_depth, out_planes, out_rows, out_cols, batch]
+ // To: [batch, out_planes, out_rows, out_cols, out_depth]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
} else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
+ // From: [batch, out_cols, out_rows, out_planes, out_depth]
+ // To: [out_depth, out_cols, out_rows, out_planes, batch]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
}
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
- eigen_assert(padding_top >= 0);
- eigen_assert(padding_left >= 0);
- eigen_assert(padding_bottom >= 0);
- eigen_assert(padding_right >= 0);
-
- // The output_backward has dimensions out_depth X out_plaens X out_rows X
- // out_cols X OTHERS
- // When we extract the image patches from output_backward (with input as the
- // kernel), it will have dimensions
- // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes *
- // kernel_rows * kernel_cols) X OTHERS
- DSizes<TensorIndex, 4> pre_contract_dims;
+ // Reorder input dimensions.
+ array<TensorIndex, 5> input_shuffle;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[3] = 1;
- for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[3] *= out.dimension(i);
- }
+ // From: [in_depth, in_planes, in_rows, in_cols, batch]
+ // To: [in_depth, batch, in_planes, in_rows, in_cols]
+ input_shuffle = {0, 4, 1, 2, 3};
} else {
- pre_contract_dims[3] = kernelFilters;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = 1;
- for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= out.dimension(i);
- }
+ // From: [batch, in_cols, in_rows, in_planes, in_depth]
+ // To: [in_cols, in_rows, in_planes, batch, in_depth]
+ input_shuffle = {1, 2, 3, 0, 4};
}
- // The input has dimensions in_depth X (input_planes * input_rows *
- // input_cols) X OTHERS
- DSizes<TensorIndex, 3> input_dims;
+ // Input is playing the role of a "kernel" in this convolution.
+ DSizes<TensorIndex, 2> input_dims;
if (isColMajor) {
input_dims[0] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[2] = 1;
- for (int i = 4; i < NumDims; ++i) {
- input_dims[2] *= in.dimension(i);
- }
- eigen_assert(input_dims[2] == pre_contract_dims[3]);
+ input_dims[1] = batch * inputPlanes * inputRows * inputCols;
} else {
- input_dims[2] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[0] = 1;
- for (int i = 0; i < NumDims - 4; ++i) {
- input_dims[0] *= in.dimension(i);
- }
- eigen_assert(input_dims[0] == pre_contract_dims[0]);
+ input_dims[1] = kernelChannels;
+ input_dims[0] = inputCols * inputRows * inputPlanes * batch;
}
- // We will contract along dimensions (1, 2) in and (1, 3) in out, if
- // this is col-major.
- // For row-major, it's dimensions (0, 1) in and (0, 2) in out.
- array<IndexPair<TensorIndex>, 2> contract_dims;
+ // Molds the output of the patch extraction result into a 2D tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- // col-major: in.contract(output.patches)
- contract_dims[0] = IndexPair<TensorIndex>(1, 1);
- contract_dims[1] = IndexPair<TensorIndex>(2, 3);
+ pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
+ pre_contract_dims[1] =
+ kernelPlanes * kernelRows * kernelCols * kernelFilters;
} else {
- // row-major: output.patches.contract(in)
- contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+ pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+ pre_contract_dims[0] =
+ kernelFilters * kernelCols * kernelRows * kernelPlanes;
}
- // After the contraction, the kernel will have dimension
- // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
- // We will need to shuffle the first two dimensions and reverse the spatial
- // dimensions.
- // The end shape is:
- // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
+ // We will contract along the collapsed dimension that contains the
+ // batch, inputPlanes, inputRows and inputCols.
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- // This is the shape of the kernel *before* the shuffling.
- DSizes<TensorIndex, 5> kernel_dims;
+ // Dimensions after contraction.
+ DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
- kernel_dims[0] = kernelChannels;
- kernel_dims[1] = kernelFilters;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelRows;
- kernel_dims[4] = kernelCols;
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = kernelPlanes;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelCols;
+ post_contract_dims[4] = kernelFilters;
} else {
- kernel_dims[0] = kernelCols;
- kernel_dims[1] = kernelRows;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelFilters;
- kernel_dims[4] = kernelChannels;
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = kernelCols;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelPlanes;
+ post_contract_dims[4] = kernelChannels;
}
- // Flip filters and channels.
+ // Reorder output of contraction to valid filter shape.
array<TensorIndex, 5> kernel_shuffle;
if (isColMajor) {
- kernel_shuffle[0] = 1;
- kernel_shuffle[1] = 0;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 3;
- kernel_shuffle[4] = 4;
+ // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+ // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+ kernel_shuffle = {4, 0, 1, 2, 3};
} else {
- kernel_shuffle[0] = 0;
- kernel_shuffle[1] = 1;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 4;
- kernel_shuffle[4] = 3;
+ // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+ // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+ kernel_shuffle = {1, 2, 3, 4, 0};
}
- // Reverse the spatial dimensions.
- array<bool, 5> kernel_reverse;
+ // Reverse kernel backprop dimensions.
+ array<TensorIndex, 5> kernel_reverse;
if (isColMajor) {
- kernel_reverse[0] = false;
- kernel_reverse[1] = false;
- kernel_reverse[2] = true;
- kernel_reverse[3] = true;
- kernel_reverse[4] = true;
+ kernel_reverse = {false, false, true, true, true};
} else {
- kernel_reverse[0] = true;
- kernel_reverse[1] = true;
- kernel_reverse[2] = true;
- kernel_reverse[3] = false;
- kernel_reverse[4] = false;
+ kernel_reverse = {true, true, true, false, false};
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
- return choose(
- Cond<internal::traits<Input>::Layout == ColMajor>(),
- input.reshape(input_dims)
- .contract(output_backward
+ // Create convolution input (aka source of patches) from output backward
+ // tensor by shuffling dimensions.
+ const auto the_input =
+ output_backward.shuffle(output_backward_shuffle).eval();
+
+ // Create convolution kernel (aka filter) from input by shuffling and
+ // reshaping.
+ const auto the_kernel =
+ input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ the_kernel.contract(
+ the_input
.extract_volume_patches(
inputPlanes, inputRows, inputCols, 1, 1, 1,
stridePlanes, strideRows, strideCols,
-
- padding_ztop, padding_zbottom, padding_top,
- padding_bottom, padding_left, padding_right)
+ top_pad_planes, bottom_pad_planes, top_pad_rows,
+ bottom_pad_rows, left_pad_cols, right_pad_cols)
.reshape(pre_contract_dims),
- contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle),
- output_backward
- .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims)
- .contract(input.reshape(input_dims), contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle));
+ contract_dims),
+ the_input
+ .extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols, top_pad_planes,
+ bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+ left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims)
+ .contract(the_kernel, contract_dims))
+ .reshape(post_contract_dims)
+ .shuffle(kernel_shuffle)
+ .reverse(kernel_reverse);
}
} // end namespace Eigen
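
Note: the rewritten CuboidConvolutionBackwardInput derives its paddings directly from the input/output extents instead of re-inferring the forward padding type. A worked numeric example of the formulas that appear in the hunk (the concrete sizes are illustrative, not taken from a test):

#include <algorithm>
#include <cstdio>

int main() {
  // Forward convolution: input extent 7, kernel 3, stride 2 with SAME-style
  // padding gives output extent 4 (= ceil(7 / 2)).
  const int input = 7, kernel = 3, stride = 2, output = 4;

  // Forward padding, as computed in the patch:
  //   forward_pad = max(0, ((output - 1) * stride + kernel - input) / 2)
  const int forward_pad =
      std::max(0, ((output - 1) * stride + kernel - input) / 2);

  // Padding applied to output_backward before extracting volume patches:
  //   pad_top    = kernel - 1 - forward_pad
  //   pad_bottom = input - (output - 1) * stride - 2 - pad_top + kernel
  const int pad_top = kernel - 1 - forward_pad;
  const int pad_bottom = input - (output - 1) * stride - 2 - pad_top + kernel;

  std::printf("forward_pad=%d pad_top=%d pad_bottom=%d\n",
              forward_pad, pad_top, pad_bottom);
  // With these numbers: forward_pad=1, pad_top=1, pad_bottom=1. The inflated
  // output extent is (output - 1) * stride + 1 + pad_top + pad_bottom
  // = 7 + 2 = 9, which a size-3 patch at stride 1 reduces back to
  // 9 - 3 + 1 = 7, the original input extent.
}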
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index cb0a76dac4..960920c55b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -189,14 +189,19 @@ SpatialConvolutionBackwardInput(
}
#endif
- // Reorder the dimensions to filters X patch_rows X patch_cols X channels
+ // Reorder the dimensions to:
+ // filters x patch_rows x patch_cols x channels
array<TensorIndex, 4> kernel_shuffle;
if (isColMajor) {
+ // From: filters x channels x rows x cols
+ // To: filters x rows x cols x channels
kernel_shuffle[0] = 0;
kernel_shuffle[1] = 2;
kernel_shuffle[2] = 3;
kernel_shuffle[3] = 1;
} else {
+ // From: cols x rows x channels x filters
+ // To: channels x cols x rows x filters
kernel_shuffle[0] = 2;
kernel_shuffle[1] = 0;
kernel_shuffle[2] = 1;
@@ -233,8 +238,8 @@ SpatialConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // the kernelRows and the kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, the kernelRows and the kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -327,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic,
- const Input> > > > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward> > > >::type
@@ -451,12 +449,16 @@ SpatialConvolutionBackwardKernel(
eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
-
+ // We will contract along the collapsed dimension that contains the
+ // outputCols, outputRows and OTHERS.
array<IndexPair<TensorIndex>, 1> contract_dims;
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ if (isColMajor) {
+ // col-major: output_backward.contract(input.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ } else {
+ // row-major: input.patches.contract(output_backward)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ }
// After the contraction, the kernel will have the desired shape
// out_depth X in_shape X kernel_rows X kernel_cols
@@ -482,8 +484,7 @@ SpatialConvolutionBackwardKernel(
kernelRows, kernelCols, row_stride, col_stride,
row_in_stride, col_in_stride, 1, 1, padding_top,
padding_bottom, padding_left, padding_right, OutScalar(0))
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
+ .reshape(pre_contract_dims),
contract_dims)
.reshape(kernel_dims),
input
@@ -492,7 +493,6 @@ SpatialConvolutionBackwardKernel(
padding_top, padding_bottom, padding_left,
padding_right, OutScalar(0))
.reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
.contract(output_backward.reshape(output_dims), contract_dims)
.reshape(kernel_dims));
}
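
Note: SpatialConvolutionBackwardKernel now contracts against the un-shuffled patch tensor by choosing the matching contraction index pair instead of transposing one operand first (the removed shuffle_dims). The equivalence, shown with two small 2-D tensors; this sketch assumes a standard Eigen checkout with the unsupported Tensor module on the include path and is not part of the patch.

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> a(2, 3), b(4, 3);
  a.setRandom();
  b.setRandom();

  // Contract a's dim 1 with b's dim 1 directly...
  Eigen::array<Eigen::IndexPair<int>, 1> direct = {Eigen::IndexPair<int>(1, 1)};
  Eigen::Tensor<float, 2> c1 = a.contract(b, direct);

  // ...instead of first shuffling b to shape (3, 4) and contracting (1, 0),
  // which is what the old shuffle_dims + IndexPair(1, 0) combination did.
  Eigen::array<int, 2> transpose = {1, 0};
  Eigen::array<Eigen::IndexPair<int>, 1> after = {Eigen::IndexPair<int>(1, 0)};
  Eigen::Tensor<float, 2> c2 = a.contract(b.shuffle(transpose), after);

  // Both forms compute sum_k a(i, k) * b(j, k); the difference is ~0.
  Eigen::Tensor<float, 0> diff = (c1 - c2).abs().maximum();
  std::cout << "max |c1 - c2| = " << diff() << "\n";
}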
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec9659..673ec1458b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ /*num_batches*/ 1);
Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
patch_cols);
- Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
- output_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, /*num_batches*/ 1);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(id, i, j, k) *
- output_backward(od, output_i, output_j, output_k);
+ expected += input(id, i, j, k, /*batch*/ 0) *
+ output_backward(od, output_i, output_j,
+ output_k, /*batch*/ 0);
}
}
}
@@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
- input_depth);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+ input_planes, input_depth);
Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
input_depth, output_depth);
- Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
- output_planes, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(k, j, i, id) *
- output_backward(output_k, output_j, output_i, od);
+ expected += input(/*batch*/ 0, k, j, i, id) *
+ output_backward(/*batch*/ 0, output_k, output_j,
+ output_i, od);
}
}
}
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
new file mode 100644
index 0000000000..87e41b89b3
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
+#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+using ::tensorflow::TTypes;
+
+template <typename Scalar, typename Device>
+class SpatialConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 4>::ConstTensor;
+ using Filter = TTypes<float, 4>::ConstTensor;
+ using Output = TTypes<float, 4>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
+
+ SpatialConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void SpatialConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::SpatialConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void SpatialConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using InputBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+ filter, output_backward, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void SpatialConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using FilterBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+    OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+ input, output_backward, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
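+
+// A minimal usage sketch (assuming an Eigen::ThreadPoolDevice named `device`
+// and an iteration count `iters`; see eigen_benchmark_cpu_test.cc for the
+// registered benchmarks):
+//
+//   SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice> suite(
+//       iters, device);
+//   Eigen::DSizes<Eigen::Index, 4> input_dims(32, 56, 56, 64);   // N, H, W, C
+//   Eigen::DSizes<Eigen::Index, 4> filter_dims(3, 3, 64, 192);   // FH, FW, C, FC
+//   suite.SpatialConvolution(input_dims, filter_dims);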
+
+template <typename Scalar, typename Device>
+class CuboidConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 5>::ConstTensor;
+ using Filter = TTypes<float, 5>::ConstTensor;
+ using Output = TTypes<float, 5>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
+
+ CuboidConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void CuboidConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using InputBackward = TTypes<float, 5>::Tensor;
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+ Eigen::Index input_planes = input_dims[3];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward, input_planes, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void CuboidConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using FilterBackward = TTypes<float, 5>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+ Eigen::Index filter_planes = filter_dims[2];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward, filter_planes, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
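+
+// A minimal usage sketch (same assumptions as for the spatial suite above):
+//
+//   CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice> suite(
+//       iters, device);
+//   Eigen::DSizes<Eigen::Index, 5> input_dims(8, 25, 25, 25, 4);   // N, H, W, P, C
+//   Eigen::DSizes<Eigen::Index, 5> filter_dims(5, 5, 5, 4, 16);    // FH, FW, FP, C, FC
+//   suite.CuboidConvolution(input_dims, filter_dims);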
+
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
new file mode 100644
index 0000000000..ec949ddc84
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_benchmark.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+#define CREATE_THREAD_POOL(threads) \
+ Eigen::ThreadPool tp(threads); \
+ Eigen::ThreadPoolDevice device(&tp, threads)
+
+// -------------------------------------------------------------------------- //
+// Spatial Convolutions //
+// -------------------------------------------------------------------------- //
+
+void SpatialConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolution(input_dims, filter_dims);
+
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
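+// For the conv2_00 benchmark below (N=32, H=W=56, C=64, FC=192, FH=FW=3) this
+// counts 32 * 56 * 56 * 192 (~1.9e7) output elements at 64 * 3 * 3 = 576
+// multiply-adds each, i.e. roughly 1.1e10 items per iteration.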
+
+void SpatialConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void SpatialConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_batches * input_height * input_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro argument names: ---------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_SPATIAL_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \
+ BM_##prefix##_CPU_##NT##T_in_##N##_##H##_##W##_##C##_f_##FC##_##FH##_##FW
+
+#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK( \
+ BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW))
+
+#define BM_SpatialConvolutions(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolution(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdInput(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdInput(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdKernel(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdKernel(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(16, N, H, W, C, FC, FH, FW, LABEL);
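+
+// For illustration, BM_SpatialConvolution(2, 32, 56, 56, 64, 192, 3, 3,
+// "conv2_00") expands (roughly) to:
+//
+//   static void BM_SpatialConvolution_CPU_2T_in_32_56_56_64_f_192_3_3(int iters) {
+//     ::tensorflow::testing::SetLabel("conv2_00");
+//     SpatialConvolution(iters, 2, 32, 56, 56, 64, 192, 3, 3);
+//   }
+//   BENCHMARK(BM_SpatialConvolution_CPU_2T_in_32_56_56_64_f_192_3_3);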
+
+// ImageNet Forward Convolutions -------------------------------------------- //
+
+BM_SpatialConvolutions(32, // batch size
+ 56, 56, 64, // input: height, width, depth
+ 192, 3, 3, // filter: count, height, width
+ "conv2_00");
+
+BM_SpatialConvolutions(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutions(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutions(32, 7, 7, 48, 128, 5, 5, "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// Benchmarks from https://github.com/soumith/convnet-benchmarks
+BM_SpatialConvolutions(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
+BM_SpatialConvolutions(128, 64, 64, 64, 128, 9, 9, "convnet-layer2");
+BM_SpatialConvolutions(128, 32, 32, 128, 128, 9, 9, "convnet-layer3");
+BM_SpatialConvolutions(128, 16, 16, 128, 128, 7, 7, "convnet-layer4");
+BM_SpatialConvolutions(128, 13, 13, 384, 384, 3, 3, "convnet-layer5");
+
+// ImageNet BackwardInput Convolutions -------------------------------------- //
+
+BM_SpatialConvolutionsBwdInput(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// ImageNet BackwardKernel Convolutions ------------------------------------- //
+
+BM_SpatialConvolutionsBwdKernel(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// -------------------------------------------------------------------------- //
+// Cuboid Convolutions //
+// -------------------------------------------------------------------------- //
+
+void CuboidConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_planes, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width,
+ int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolution(input_dims, filter_dims);
+
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops = num_computed_elements *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
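+// For the conv3d_depth4 benchmark below (N=8, H=W=P=25, C=4, FC=16,
+// FH=FW=FP=5) this counts 8 * 25 * 25 * 25 * 16 = 2e6 output elements at
+// 4 * 5 * 5 * 5 = 500 multiply-adds each, i.e. about 1e9 items per iteration.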
+
+void CuboidConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_batches * input_height * input_width * input_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro argument names: ---------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// P: planes
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+// FP: filter planes
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_CUBOID_NAME(p, NT, N, H, W, P, C, FC, FH, FW, FP) \
+ BM_CONCAT(BM_##p##_CPU_##NT##T_in_##N##_##H##_##W##_##P##_##C, \
+ _f_##FC##_##FH##_##FW##_##FP)
+
+#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
+ FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK( \
+ BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdKernel(NT, N, H, W, P, C, FC, FH, FW, FP, \
+ LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
+ FC, FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdInput(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdInput(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdKernel(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdKernel(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+// Random Cuboid Convolutions ----------------------------------------------- //
+// TODO(ezhulenev): find representative dims for cuboid convolutions (find
+// models using Conv3D ops).
+
+BM_CuboidConvolutions(8, // batch size
+                       25, 25, 25, 4,  // input: height, width, planes, depth
+                       16, 5, 5, 5,    // filter: count, height, width, planes
+ "conv3d_depth4");
+BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 62e9f9123d..c41fbc42d3 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -21,6 +21,1362 @@ limitations under the License.
namespace Eigen {
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelDepth * kernelRows * kernelCols;
+//    1: out_depth * out_height * out_width * OTHERS (e.g. batches, etc...)
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+// col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
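+// For example, with patch_depth = C, patch_planes = P and patch_rows = R, a
+// row index of this "virtual matrix" decomposes as
+//   patchId = depth + C * (plane + P * (row + R * col)),
+// which is what loadCoeff()/loadPacket() below invert using the precomputed
+// fast divisors (m_fastDimZero, m_fastRowStride, m_fastColStride).
+//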
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionInputMapper<
+ Scalar_, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(
+ const TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension,
+ const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+ Device>& tensor,
+ const nocontract_t&, const nocontract_t&, const contract_t&,
+ const contract_t&)
+ : m_impl(tensor.impl().impl()) {
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_patch_depth = tensor.impl().dimensions()[0];
+ m_patch_planes = tensor.impl().dimensions()[1];
+ m_patch_rows = tensor.impl().dimensions()[2];
+ m_patch_cols = tensor.impl().dimensions()[3];
+ m_num_patches = tensor.impl().dimensions()[4];
+ } else {
+ const int NumDims = tensor.impl().dimensions().size();
+ m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+ }
+
+ // Strides for the output tensor.
+    // IMPORTANT: These strides are used to locate an element in a patch at
+    // depth zero (channel 0), which is not quite the same as the "traditional"
+    // stride.
+ m_rowStride = m_patch_planes;
+ m_colStride = m_patch_rows * m_rowStride;
+ m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+ m_otherStride = m_patchStride * m_num_patches;
+
+ m_outputPlanes = tensor.impl().outputPlanes();
+ m_outputRows = tensor.impl().outputRows();
+ m_outputCols = tensor.impl().outputCols();
+
+ m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+ m_plane_strides = tensor.impl().userPlaneStride();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_plane_strides = tensor.impl().userInPlaneStride();
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputDepth = tensor.impl().impl().dimensions()[0];
+ m_inputPlanes = tensor.impl().impl().dimensions()[1];
+ m_inputRows = tensor.impl().impl().dimensions()[2];
+ m_inputCols = tensor.impl().impl().dimensions()[3];
+ } else {
+ const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+ m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+ }
+
+ // Strides for navigating through the input tensor.
+ m_planeInputStride = m_inputDepth;
+ m_rowInputStride = m_inputDepth * m_inputPlanes;
+ m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+ m_patchInputStride =
+ m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+ m_planePaddingTop = tensor.impl().planePaddingTop();
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+ m_fastInputPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+ m_fastInputRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+ m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+ m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+ m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+ m_fastOutputPlanesRows =
+ internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+ }
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+ : m_impl(base_mapper.m_impl) {
+ m_patch_depth = base_mapper.m_patch_depth;
+ m_patch_planes = base_mapper.m_patch_planes;
+ m_patch_rows = base_mapper.m_patch_rows;
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+
+ m_rowStride = base_mapper.m_rowStride;
+ m_colStride = base_mapper.m_colStride;
+ m_patchStride = base_mapper.m_patchStride;
+ m_otherStride = base_mapper.m_otherStride;
+
+ m_planeInputStride = base_mapper.m_planeInputStride;
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+ m_otherInputStride = base_mapper.m_otherInputStride;
+
+ m_inputDepth = base_mapper.m_inputDepth;
+ m_inputPlanes = base_mapper.m_inputPlanes;
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputPlanes = base_mapper.m_outputPlanes;
+ m_outputRows = base_mapper.m_outputRows;
+ m_outputCols = base_mapper.m_outputCols;
+
+ m_plane_strides = base_mapper.m_plane_strides;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_plane_strides = base_mapper.m_in_plane_strides;
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_planePaddingTop = base_mapper.m_planePaddingTop;
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastRowStride = base_mapper.m_fastRowStride;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastOutputCols = base_mapper.m_fastOutputCols;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+ }
+
+ // If true, turns off some optimizations for loading packets since the image
+ // patches are "non-standard" such as there are non-trivial strides or
+ // inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+ m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+ m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual
+ // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+ // gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
+ return m_impl;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ private:
+ friend class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>;
+
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset * m_in_col_strides;
+ const Index origInputCol =
+ (m_patch_col_inflate_strides == 1)
+ ? inputCol
+ : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+ const Index origInputRow =
+ (m_patch_row_inflate_strides == 1)
+ ? inputRow
+ : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
+ const Index origInputPlane =
+ (m_patch_plane_inflate_strides == 1)
+ ? inputPlane
+ : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
+
+ if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
+ origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
+ origInputPlane >= m_inputPlanes ||
+ (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+ (inputRow != origInputRow * m_patch_row_inflate_strides) ||
+ (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + origInputPlane * m_planeInputStride +
+ origInputRow * m_rowInputStride +
+ origInputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ eigen_assert(!nonStandardPatches());
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset;
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset;
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
+ inputRow >= m_inputRows || inputPlane < 0 ||
+ inputPlane >= m_inputPlanes) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ if (nonStandardPatches()) {
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+ return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+ eigen_assert(!nonStandardPatches());
+
+ if ((patchDepth() % packetSize) == 0) {
+ return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ } else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
+ const Index patchOffsets[2] = {
+ patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+ const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+ patchOffsets[1] / m_fastColStride};
+ eigen_assert(colOffsets[0] <= colOffsets[1]);
+
+ const Index inputCols[2] = {colIndex + colOffsets[0],
+ colIndex + colOffsets[1]};
+ if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputCols[0] == inputCols[1]) {
+ const Index rowOffsets[2] = {
+ (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
+ (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
+ eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+ const Index inputRows[2] = {rowIndex + rowOffsets[0],
+ rowIndex + rowOffsets[1]};
+
+ if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputRows[0] == inputRows[1]) {
+ const Index planeOffsets[2] = {
+ patchOffsets[0] - colOffsets[0] * m_colStride -
+ rowOffsets[0] * m_rowStride,
+ patchOffsets[1] - colOffsets[1] * m_colStride -
+ rowOffsets[1] * m_rowStride};
+ eigen_assert(planeOffsets[0] <= planeOffsets[1]);
+ const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
+ planeIndex + planeOffsets[1]};
+
+ if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
+ const Index depth = patchId - patchOffsets[0] * patchDepth();
+ const Index inputIndex =
+ depth + inputPlanes[0] * m_planeInputStride +
+ inputRows[0] * m_rowInputStride +
+ inputCols[0] * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+ }
+ }
+ }
+
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ eigen_assert(!nonStandardPatches());
+ eigen_assert((patchDepth() % packetSize) == 0);
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+ eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
+ inputCol >= m_inputCols || inputRow >= m_inputRows ||
+ inputPlane >= m_inputPlanes) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
+ Index colIndex, Index otherIndex) const {
+ const int packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_MAX
+ typename internal::remove_const<Scalar>::type values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] =
+ loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+ Packet rslt = internal::pload<Packet>(values);
+ return rslt;
+ }
+
+ // Precompute the indices (plane, row, col, other) of the first element of
+ // the given patch index, within the output tensor of the TensorVolumePatchOp.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
+ Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
+ Index& otherIndex) const {
+ const int NumInputDims = array_size<
+ typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+
+ // Check if patchIndex might contain batch and other dimensions.
+ otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
+
+ // Compute index of the patch within the batch (and other dimensions).
+ const Index patch3DIndex = (NumInputDims == 4)
+ ? patchIndex
+ : (patchIndex - otherIndex * m_num_patches);
+
+ otherIndex *= m_patchInputStride;
+
+ colIndex = patch3DIndex / m_fastOutputPlanesRows;
+ rowIndex =
+ (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
+ planeIndex =
+ patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
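+    // In other words, patch3DIndex = planeIndex + m_outputPlanes * (rowIndex +
+    // m_outputRows * colIndex); user strides and padding are applied below.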
+
+ colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+ rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+ planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
+ }
+
+ Index m_patch_depth; // number of channels in the patch
+ Index m_patch_planes; // number of planes in the patch
+ Index m_patch_rows; // number of rows in the patch
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract
+
+ // Strides for the output tensor.
+ Index m_rowStride;
+ Index m_colStride;
+ Index m_patchStride;
+ Index m_otherStride;
+
+ Index m_planeInputStride; // Plane stride in the input tensor
+ Index m_rowInputStride; // Row stride in the input tensor
+ Index m_colInputStride; // Col stride in the input tensor
+ Index m_patchInputStride; // Patch stride in the input tensor
+ Index m_otherInputStride;
+
+ Index m_inputDepth; // Depth of the input tensor
+ Index m_inputPlanes; // Number of planes in the input tensor
+ Index m_inputRows; // Number of rows in the input tensor
+ Index m_inputCols; // Number of cols in the input tensor
+
+ Index m_outputPlanes; // Number of output planes
+ Index m_outputRows; // Number of output rows
+ Index m_outputCols; // Number of output cols
+ Index m_outputPlanesRows; // Cached outputPlanes * outputRows.
+
+ Index m_plane_strides; // User specified plane stride
+ Index m_row_strides; // User specified row stride
+ Index m_col_strides; // User specified col stride
+
+ // User specified plane/row/col atrous convolution strides.
+ Index m_in_plane_strides;
+ Index m_in_row_strides;
+ Index m_in_col_strides;
+
+ // User specified plane/row/col inflation strides in the image patch.
+ Index m_patch_plane_inflate_strides;
+ Index m_patch_row_inflate_strides;
+ Index m_patch_col_inflate_strides;
+
+ Index m_planePaddingTop; // Plane padding
+ Index m_rowPaddingTop; // Row padding
+ Index m_colPaddingLeft; // Column padding
+
+ // Fast representation of various divisors.
+ internal::TensorIntDivisor<Index> m_fastNumPatches;
+
+ internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastInputRowStride;
+ internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+ internal::TensorIntDivisor<Index> m_fastRowStride;
+ internal::TensorIntDivisor<Index> m_fastColStride;
+
+ internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth
+ internal::TensorIntDivisor<Index> m_fastOutputPlanes;
+ internal::TensorIntDivisor<Index> m_fastOutputRows;
+ internal::TensorIntDivisor<Index> m_fastOutputCols;
+ internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
+
+ const TensorEvaluator<ArgType, Device> m_impl;
+};
+
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename packet_traits<Scalar>::half HalfPacket;
+
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ ParentMapper;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef Self LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper),
+ m_depth_offset(vert_offset),
+ m_col_offset(horiz_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const Self& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper.m_base_mapper),
+ m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+ m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+ return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
+ Index j) const {
+ return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
+ Index j) const {
+ return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
+ j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
+ loadCoeffStandard(Index i) const {
+ return m_base_mapper.loadCoeffStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+ return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ loadPacketStandard(Index i) const {
+ return m_base_mapper.loadPacketStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC bool aligned(Index) const {
+ return false;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_base_mapper.nonStandardPatches();
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const {
+ return m_base_mapper.m_patch_depth;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const {
+ return m_base_mapper.m_patch_planes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const {
+ return m_base_mapper.m_patch_rows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const {
+ return m_base_mapper.m_patch_cols;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
+ const Index p = m_planeIndex + plane;
+ return p < 0 || p >= m_base_mapper.m_inputPlanes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+ const Index r = m_rowIndex + row;
+ return r < 0 || r >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+ const Index c = m_colIndex + col;
+ return c < 0 || c >= m_base_mapper.m_inputCols;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
+ const Index col) const {
+ const Index p = m_planeIndex + plane;
+ const Index r = m_rowIndex + row;
+ const Index c = m_colIndex + col;
+ return p * m_base_mapper.m_planeInputStride +
+ r * m_base_mapper.m_rowInputStride +
+ c * m_base_mapper.m_colInputStride + m_otherIndex;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index planeOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ const Index planeOffset = patchOffset -
+ colOffset * m_base_mapper.m_colStride -
+ rowOffset * m_base_mapper.m_rowStride;
+ return planeOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index rowOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ return rowOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index colOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return colOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index depthOffset() const {
+ const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
+ return patchOffset;
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
+ getLinearMapper(Index i, Index j) const {
+ return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+ }
+
+ private:
+ const ParentMapper& m_base_mapper;
+ Index m_depth_offset; // First row in the input matrix
+ Index m_col_offset; // First col in the input matrix
+
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_planeIndex;
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
+};
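
The planeOffset(), rowOffset(), colOffset() and depthOffset() accessors above all
decompose the same linear depth offset into patch coordinates. A minimal standalone
sketch of that decomposition (not part of this patch): it uses plain integer division
where the mapper uses TensorIntDivisor, assumes the column-major patch layout implied
above (depth innermost, then plane, row, col, i.e. m_rowStride == patch_planes and
m_colStride == patch_rows * patch_planes), and the function and type names are
illustrative.

#include <cstdint>

struct PatchCoords {
  std::int64_t depth, plane, row, col;
};

PatchCoords DecomposePatchOffset(std::int64_t linear_offset,
                                 std::int64_t patch_depth,
                                 std::int64_t patch_planes,
                                 std::int64_t patch_rows) {
  const std::int64_t row_stride = patch_planes;               // m_rowStride
  const std::int64_t col_stride = patch_rows * patch_planes;  // m_colStride
  const std::int64_t patch_offset = linear_offset / patch_depth;
  PatchCoords c;
  c.depth = linear_offset % patch_depth;                             // depthOffset()
  c.col = patch_offset / col_stride;                                 // colOffset()
  c.row = (patch_offset - c.col * col_stride) / row_stride;          // rowOffset()
  c.plane = patch_offset - c.col * col_stride - c.row * row_stride;  // planeOffset()
  return c;
}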
+
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted volume patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ...
+// A1 B1 C1 D1 E1 F1 G1 H1 ...
+// A2 B2 C2 D2 E2 F2 G2 H2 ...
+// A3 B3 C3 D3 E3 F3 G3 H3 ...
+// A4 B4 C4 D4 E4 F4 G4 H4 ...
+// A5 B5 C5 D5 E5 F5 G5 H5 ...
+// A6 B6 C6 D6 E6 F6 G6 H6 ...
+// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A8 ...
+// ...
+//
+// Packing yields row major output (A0 beside A1 in memory):
+// A0 A1 A2 A3 A4 A5 A6 A7
+// B0 B1 B2 B3 B4 B5 B6 B7
+// C0 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment, int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if ((packet_size % 4) == 0 && !non_standard_patches) {
+ const Index patch_depth = rhs.patchDepth();
+
+ if ((patch_depth % packet_size) == 0) {
+ const Index patch_cols = rhs.patchCols();
+ const Index patch_rows = rhs.patchRows();
+ const Index patch_planes = rhs.patchPlanes();
+
+ const Index startCol = rhs.colOffset();
+ const Index max_cols = std::min<Index>(
+ Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
+ startCol,
+ patch_cols);
+
+ for (Index c = startCol; c < max_cols; ++c) {
+ eigen_assert(k < peeled_k);
+
+ const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+ const Index max_rows = std::min<Index>(
+ Eigen::divup(
+ peeled_k - c * patch_rows * patch_planes * patch_depth,
+ patch_planes * patch_depth) +
+ startRow,
+ patch_rows);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = startRow; r < max_rows; ++r) {
+ eigen_assert(k < peeled_k);
+
+ const Index startPlane =
+ ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
+ const Index max_planes = std::min<Index>(
+ Eigen::divup(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth, // row
+ patch_depth) +
+ startPlane,
+ patch_planes);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = startPlane; p < max_planes; ++p) {
+ eigen_assert(k < peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index startDepth =
+ ((c == startCol) && (r == startRow) && (p == startPlane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = std::min<Index>(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth - // row
+ p * patch_depth + // plane
+ startDepth,
+ patch_depth);
+ eigen_assert((max_depth - startDepth) % packet_size == 0);
+
+ for (Index d = startDepth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketFast(k);
+ kernel.packet[1] = dm1.loadPacketFast(k);
+ kernel.packet[2] = dm2.loadPacketFast(k);
+ kernel.packet[3] = dm3.loadPacketFast(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketStandard(k);
+ kernel.packet[1] = dm1.loadPacketStandard(k);
+ kernel.packet[2] = dm2.loadPacketStandard(k);
+ kernel.packet[3] = dm3.loadPacketStandard(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
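
For reference, a minimal standalone sketch of the 4-column pack-and-transpose step
used by the specialization above (not part of this patch). It assumes Eigen's
internal packet primitives (ploadu / ptranspose / pstoreu / pset1); the function
name and signature are illustrative only.

#include <Eigen/Core>

void PackFourColumns(const float* col0, const float* col1, const float* col2,
                     const float* col3, float* block) {
  using namespace Eigen::internal;
  typedef packet_traits<float>::type Packet;  // e.g. Packet4f (SSE) or Packet8f (AVX)
  const int packet_size = unpacket_traits<Packet>::size;

  // One packet per output column; a padded column would use pset1<Packet>(0)
  // instead, exactly as the pad0..pad3 branches above do.
  PacketBlock<Packet, 4> kernel;
  kernel.packet[0] = ploadu<Packet>(col0);
  kernel.packet[1] = ploadu<Packet>(col1);
  kernel.packet[2] = ploadu<Packet>(col2);
  kernel.packet[3] = ploadu<Packet>(col3);

  // Transpose in registers so that, once the packets are stored in order below,
  // the 4 column values belonging to each row sit next to each other.
  ptranspose(kernel);

  pstoreu(block + 0 * packet_size, kernel.packet[0]);
  pstoreu(block + 1 * packet_size, kernel.packet[1]);
  pstoreu(block + 2 * packet_size, kernel.packet[2]);
  pstoreu(block + 3 * packet_size, kernel.packet[3]);
}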
+
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ const int packet_size = 2;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if (!non_standard_patches) {
+ const Index patch_depth = rhs.patchDepth();
+
+ if ((patch_depth % packet_size) == 0) {
+ const Index patch_cols = rhs.patchCols();
+ const Index patch_rows = rhs.patchRows();
+ const Index patch_planes = rhs.patchPlanes();
+
+ const Index startCol = rhs.colOffset();
+ const Index max_cols = std::min<Index>(
+ Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
+ startCol,
+ patch_cols);
+
+ for (Index c = startCol; c < max_cols; ++c) {
+ eigen_assert(k < peeled_k);
+
+ const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+ const Index max_rows = std::min<Index>(
+ Eigen::divup(
+ peeled_k - c * patch_rows * patch_planes * patch_depth,
+ patch_planes * patch_depth) +
+ startRow,
+ patch_rows);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = startRow; r < max_rows; ++r) {
+ eigen_assert(k < peeled_k);
+
+ const Index startPlane =
+ ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
+ const Index max_planes = std::min<Index>(
+ Eigen::divup(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth, // row
+ patch_depth) +
+ startPlane,
+ patch_planes);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = startPlane; p < max_planes; ++p) {
+ eigen_assert(k < peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index startDepth =
+ ((c == startCol) && (r == startRow) && (p == startPlane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = std::min<Index>(
+ peeled_k -
+ c * patch_rows * patch_planes * patch_depth - // col
+ r * patch_planes * patch_depth - // row
+ p * patch_depth + // plane
+ startDepth,
+ patch_depth);
+ eigen_assert((max_depth - startDepth) % packet_size == 0);
+
+ for (Index d = startDepth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketFast(k);
+ kernel0.packet[1] = dm1.loadPacketFast(k);
+ kernel1.packet[0] = dm2.loadPacketFast(k);
+ kernel1.packet[1] = dm3.loadPacketFast(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketStandard(k);
+ kernel0.packet[1] = dm1.loadPacketStandard(k);
+ kernel1.packet[0] = dm2.loadPacketStandard(k);
+ kernel1.packet[1] = dm3.loadPacketStandard(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Special case for non-vectorized types such as float16 (packet_size = 1).
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ if (!rhs.nonStandardPatches()) {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // copy the remaining columns one at a time (nr==1)
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+} // namespace internal
+
/** CuboidConvolution
* \ingroup CXX11_NeuralNetworks_Module
*
diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h
index a3d795813d..80ab745bfe 100644
--- a/tensorflow/core/kernels/eigen_volume_patch.h
+++ b/tensorflow/core/kernels/eigen_volume_patch.h
@@ -43,6 +43,7 @@ struct CustomTensorEvaluator {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = NumDims == 6,
RawAccess = false
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 8bfa40304e..f2e0b2558f 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing")
tf_ops_fuzz_target_lib("parse_tensor_op")
+tf_ops_fuzz_target_lib("decode_compressed")
+
tf_ops_fuzz_target_lib("decode_json_example")
diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
new file mode 100644
index 0000000000..0a56f4b63f
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeCompressed : public FuzzStringInputOp {
+ void BuildGraph(const Scope& scope) override {
+ auto input =
+ tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
+ auto d1 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d1"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType(""));
+ auto d2 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d2"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("ZLIB"));
+ auto d3 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d3"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("GZIP"));
+ Scope grouper =
+ scope.WithControlDependencies(std::vector<tensorflow::Operation>{
+ d1.output.op(), d2.output.op(), d3.output.op()});
+ (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
+ }
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed);
+
+} // namespace fuzzing
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index cd2873bdca..7710cf93d6 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index ad0112e6cb..277ee2be02 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -113,10 +113,25 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
#endif
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
+
+#ifdef INTEL_MKL
+// The Eigen implementation below is not highly performant: gather_nd_generator
+// does not seem to be called in parallel, leading to very poor performance.
+// Additionally, since it uses a scalar (Tscratch) to invoke 'generate', it
+// needs to go through redundant operations like 'reshape', 'broadcast' and
+// 'sum'. The OpenMP loop below does essentially the same thing as the Eigen
+// code, but is considerably more efficient.
+#pragma omp parallel for
+ for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
+ const Eigen::array<Eigen::DenseIndex, 1> loc{i};
+ gather_nd_generator(loc);
+ }
+#else // INTEL_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
+#endif
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index c7dbefa0b4..86146f75f4 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -123,8 +123,7 @@ class AutoTuneMap {
string GetActionSummary(StringPiece action, const Parameters& params,
const Config& config) {
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
- std::string(action).c_str(),
- params.ToString().c_str(),
+ string(action).c_str(), params.ToString().c_str(),
config.ToString().c_str());
}
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index bca1cff41c..2088c13586 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -77,9 +77,9 @@ static Status TensorListDeviceCopy(
return Status::OK();
}
-#define REGISTER_LIST_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)
+#define REGISTER_LIST_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
+ TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) {
return Status::OK();
}
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
- TensorListShape);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);
bool TensorList::Decode(const VariantTensorData& data) {
tensors = data.tensors();
@@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<CPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index c591226b76..a00bf700ca 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<GPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<GPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 066a1d603b..72581c9293 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -374,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
y->tensors.reserve(x.tensors.size());
for (const Tensor& t : x.tensors) {
Tensor out_tensor;
- TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor));
+ AllocatorAttributes attr;
+ if (t.dtype() == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TF_RETURN_IF_ERROR(
+ c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
@@ -385,6 +390,20 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
+
+ case DataTypeToEnum<Variant>::value: {
+ const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
+ if (inner_x == nullptr) {
+ return errors::InvalidArgument("Input handle is not a list. Saw: '",
+ t.scalar<Variant>()().DebugString(),
+ "'");
+ }
+ TensorList inner_y;
+ TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
+ out_tensor.scalar<Variant>()() = std::move(inner_y);
+ break;
+ }
+
default:
return errors::InvalidArgument(
"Trying to compute zeros_like for unsupported dtype ",
diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h
index b43902e0b9..9198a98e47 100644
--- a/tensorflow/core/kernels/logistic-loss.h
+++ b/tensorflow/core/kernels/logistic-loss.h
@@ -86,7 +86,7 @@ class LogisticLossUpdater : public DualLossUpdater {
} else {
inverse_exp_term = 1 / (1 + exp(label * wx));
}
- return inverse_exp_term * label * example_weight;
+ return -inverse_exp_term * label * example_weight;
}
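
For reference, the sign fix above matches the closed form of the logistic-loss
derivative (a paraphrase, not text from the patch): with label y in {-1, +1} and
prediction z = wx,

\[
  \frac{\partial}{\partial z}\,\log\bigl(1 + e^{-yz}\bigr)
  = \frac{-y\,e^{-yz}}{1 + e^{-yz}}
  = \frac{-y}{1 + e^{yz}},
\]

which, scaled by the example weight, is exactly
-inverse_exp_term * label * example_weight, with
inverse_exp_term = 1 / (1 + exp(label * wx)) as computed above.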
// The smoothness constant is 4 since the derivative of logistic loss, which
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 2e8d9c623c..a495758861 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
@@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
Tensor* keys;
@@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
private:
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};
@@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
@@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
int64 value_dim = value_shape_.dim_size(0);
@@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface {
private:
TensorShape value_shape_;
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
typedef gtl::InlinedVector<V, 4> ValueArray;
std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
@@ -335,7 +333,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
size_t size() const override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return num_entries_;
}
@@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface {
auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
const auto default_flat = default_value.flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
const auto key_buckets_matrix =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
const auto value_buckets_matrix =
@@ -451,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
@@ -493,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
}
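
The changes above convert read-only table accesses from an exclusive mutex_lock to
a tf_shared_lock, so concurrent lookups no longer serialize. The same reader/writer
idea in standard C++ (illustrative only; std::shared_mutex stands in for
TensorFlow's mutex types):

#include <cstddef>
#include <shared_mutex>
#include <unordered_map>

class SharedLockedTable {
 public:
  std::size_t size() const {
    std::shared_lock<std::shared_mutex> l(mu_);  // many readers may hold this at once
    return table_.size();
  }
  void insert(int key, int value) {
    std::unique_lock<std::shared_mutex> l(mu_);  // writers remain exclusive
    table_[key] = value;
  }

 private:
  mutable std::shared_mutex mu_;
  std::unordered_map<int, int> table_;
};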
diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc
index 460d65c5c2..9209ed2ab7 100644
--- a/tensorflow/core/kernels/loss_test.cc
+++ b/tensorflow/core/kernels/loss_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -30,6 +31,24 @@ namespace {
// TODO(sibyl-Aix6ihai): add a test to show the improvements of the Newton
// modification detailed in readme.md
+// This helper checks that the dual value after the update is optimal.
+// At the optimum the dual value should be the opposite of the primal gradient.
+// This does not hold at a point where the primal is not differentiable.
+void TestComputeUpdatedDual(const DualLossUpdater &loss_updater,
+ const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) {
+ double new_dual = loss_updater.ComputeUpdatedDual(
+ num_loss_partitions, label, example_weight, current_dual, wx,
+ weighted_example_norm);
+ // The primal gradient needs to be computed after the weight update.
+ double new_wx = wx + (new_dual - current_dual) * num_loss_partitions *
+ weighted_example_norm * example_weight;
+ EXPECT_NEAR(new_dual, -loss_updater.PrimalLossDerivative(new_wx, label, 1.0),
+ 1e-5);
+}
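
In equation form, the check above can be summarized as follows (a paraphrase of the
comment, not text from the patch): writing P for num_loss_partitions, s for the
example weight, and n for the weighted example norm passed in,

\[
  wx_{\text{new}} = wx + (\alpha_{\text{new}} - \alpha_{\text{old}})\, P\, n\, s,
  \qquad
  \alpha_{\text{new}} \approx
  -\left.\frac{\partial \ell(z, y)}{\partial z}\right|_{z = wx_{\text{new}}},
\]

with the derivative evaluated at unit example weight, matching the 1.0 passed to
PrimalLossDerivative.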
+
TEST(LogisticLoss, ComputePrimalLoss) {
LogisticLossUpdater loss_updater;
EXPECT_NEAR(0.693147,
@@ -65,19 +84,12 @@ TEST(LogisticLoss, ComputeDualLoss) {
TEST(LogisticLoss, ComputeUpdatedDual) {
LogisticLossUpdater loss_updater;
- EXPECT_NEAR(0.479,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.031,
- loss_updater.ComputeUpdatedDual(
- 2 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, 0.1 /* current_dual */,
- -0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, 0.1 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SquaredLoss, ComputePrimalLoss) {
@@ -126,19 +138,12 @@ TEST(SquaredLoss, ComputeDualLoss) {
TEST(SquaredLoss, ComputeUpdatedDual) {
SquaredLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(HingeLoss, ComputePrimalLoss) {
@@ -207,48 +212,27 @@ TEST(HingeLoss, ConvertLabel) {
TEST(HingeLoss, ComputeUpdatedDual) {
HingeLossUpdater loss_updater;
- // When label=1.0, example_weight=1.0, current_dual=0.5, wx=0.3 and
- // weighted_example_norm=100.0, it turns out that the optimal value to update
- // the dual to is 0.507 which is within the permitted range and thus should be
- // the value returned.
+ // For the two tests below, y*wx=1 after the update, which is a
+ // non-differentiable point of the hinge loss, so TestComputeUpdatedDual
+ // cannot be used. Check the value of the dual variable instead.
EXPECT_NEAR(0.507,
loss_updater.ComputeUpdatedDual(
1 /* num partitions */, 1.0 /* label */,
1.0 /* example weight */, 0.5 /* current_dual */,
0.3 /* wx */, 100.0 /* weighted_example_norm */),
1e-3);
- // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6,
- // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that
- // the optimal value to update the dual to is 0.384 which is within the
- // permitted range and thus should be the value returned.
EXPECT_NEAR(-0.416,
loss_updater.ComputeUpdatedDual(
10 /* num partitions */, -1.0 /* label */,
1.0 /* example weight */, -0.4 /* current_dual */,
0.6 /* wx */, 10.0 /* weighted_example_norm */),
1e-3);
- // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range
- // and hence the closest permitted value (0.0) should be returned instead.
- EXPECT_NEAR(0.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, -0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0]
- // range and hence the closest permitted value (-1.0) should be returned
- // instead.
- EXPECT_NEAR(-1.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, -1.0 /* label */,
- 2.0 /* example weight */, -1.0 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, -0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, -1.0 /* label */,
+ 2.0 /* example weight */, -1.0 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SmoothHingeLoss, ComputePrimalLoss) {
@@ -297,19 +281,75 @@ TEST(SmoothHingeLoss, ComputeDualLoss) {
TEST(SmoothHingeLoss, ComputeUpdatedDual) {
SmoothHingeLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
+}
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
+TEST(PoissonLoss, ComputePrimalLoss) {
+ PoissonLossUpdater loss_updater;
+ EXPECT_NEAR(1.0,
+ loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
1e-3);
+ EXPECT_NEAR(21996.0,
+ loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
+ 1.0);
+ EXPECT_NEAR(0.606,
+ loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(6.64,
+ loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */,
+ 2.0 /* example weight */),
+ 1e-2);
+}
+
+TEST(PoissonLoss, ComputeDualLoss) {
+ PoissonLossUpdater loss_updater;
+ // Dual is undefined.
+ EXPECT_NEAR(
+ std::numeric_limits<double>::max(),
+ loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ 0.0,
+ loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -0.847,
+ loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -2.675,
+ loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+}
+
+TEST(PoissonLoss, ConvertLabel) {
+ PoissonLossUpdater loss_updater;
+ float example_label = -1.0;
+ // Negative label should throw an error.
+ Status status = loss_updater.ConvertLabel(&example_label);
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(PoissonLoss, ComputeUpdatedDual) {
+ PoissonLossUpdater loss_updater;
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */,
+ 1.0 /* example weight */, 0.0 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
} // namespace
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index bdc3b5778f..dd89597369 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -410,8 +410,9 @@ class StagingMap : public ResourceBase {
copy_or_move_tensors(&it->second, *key, *indices, tuple));
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
@@ -444,8 +445,9 @@ class StagingMap : public ResourceBase {
*key = it->first;
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
index 10e468ce46..693ed8a8f0 100644
--- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
+++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
@@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase {
// Exercises "delete_old_dirs".
for (int i = 0; i < 2; ++i) {
int directory_found =
- Env::Default()
- ->IsDirectory(std::string(io::Dirname(prefixes[i])))
- .code();
+ Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code();
if (delete_old_dirs) {
EXPECT_EQ(error::NOT_FOUND, directory_found);
} else {
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index cc4b6941b9..62aa7d5c29 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -103,6 +103,7 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = true,
RawAccess = false
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index afbfaa83f3..52157ed5fb 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -300,19 +300,24 @@ template <typename T>
class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConvBwdFilterPrimitive<T>* Get(
- const MklConvBwdFilterParams& convBwdFilterDims) {
+ const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
- // look into the pool for reusable primitive
- conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
+ if (do_not_cache) { /* Create new primitive always */
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*> (
MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
convBwdFilterDims));
- if (conv_bwd_filter == nullptr) {
- conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
- MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
- convBwdFilterDims, conv_bwd_filter);
+ if (conv_bwd_filter == nullptr) {
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+ convBwdFilterDims, conv_bwd_filter);
+ }
}
+
return conv_bwd_filter;
}
@@ -845,8 +850,13 @@ class MklConvCustomBackpropFilterOp
MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
padding_right, TFPaddingToMklDnnPadding(this->padding_));
- conv_bwd_filter =
- MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims);
+
+ // MKL DNN allocates large buffers when a conv gradient filter primitive is
+ // created. So we don't cache conv backward primitives when the env
+ // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
+ conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
+ convBwdFilterDims, do_not_cache);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// allocate output tensors: diff_fitler and diff_bias (w bias)
@@ -938,6 +948,9 @@ class MklConvCustomBackpropFilterOp
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index b5a98301e2..c38c9cc27c 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -174,7 +174,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
}
};
-
void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(new memory::desc(
@@ -242,19 +241,23 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConvBwdInputPrimitive<T>* Get(
- const MklConvBwdInputParams& convBwdInputDims) {
+ const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
- // look into the pool for reusable primitive
- conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
- MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
- convBwdInputDims));
-
- if (conv_bwd_input == nullptr) {
+ if (do_not_cache) { /* Always allocate primitive */
conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
- MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
- convBwdInputDims, conv_bwd_input);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+ convBwdInputDims));
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
+ }
}
+
return conv_bwd_input;
}
@@ -708,8 +711,18 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv_bwd_input =
- MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims);
+
+ // We don't cache these primitives if the env variable
+ // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if the primitive descriptor
+ // includes potentially large buffers. MKL DNN allocates buffers
+ // in the following cases:
+ // 1. Legacy CPU without AVX512/AVX2, or
+ // 2. 1x1 convolution with stride != 1
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(fwd_filter_dims, strides));
+ conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+ do_not_cache);
auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
@@ -755,6 +768,11 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
// execute convolution input bwd
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) {
+ delete conv_bwd_input;
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index c6295c7280..184e0cb003 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -271,18 +271,23 @@ class MklConvFwdPrimitive : public MklPrimitive {
template <typename T>
class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
+ static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims,
+ bool do_not_cache) {
MklConvFwdPrimitive<T>* conv_fwd = nullptr;
- // try to find a suitable one in pool
- conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
- MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
-
- if (conv_fwd == nullptr) {
+ if (do_not_cache) { /* Always create new primitive */
conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
- MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
- conv_fwd);
+ } else {
+ // try to find a suitable one in pool
+ conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+ MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+ if (conv_fwd == nullptr) {
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+ conv_fwd);
+ }
}
+
return conv_fwd;
}
@@ -894,6 +899,17 @@ class MklConvOp : public OpKernel {
// MKLDNN dilation starts from 0.
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
+    // In some cases the primitive descriptor includes potentially large
+    // buffers. We don't cache those primitives if the env variable
+    // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL-DNN allocates buffers
+    // in the following cases:
+    //   1. Legacy CPU without AVX512/AVX2, or
+    //   2. 1x1 convolution with stride != 1.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(filter_dims, strides));
+
// get a conv2d fwd from primitive pool
MklConvFwdPrimitive<T>* conv_fwd = nullptr;
if (biasEnabled) {
@@ -902,12 +918,14 @@ class MklConvOp : public OpKernel {
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
}
// allocate output tensors output_tensor and filter_out_tensor
@@ -952,6 +970,9 @@ class MklConvOp : public OpKernel {
} else {
conv_fwd->Execute(src_data, filter_data, dst_data);
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_fwd;
} catch (mkldnn::error &e) {
string error_msg = tensorflow::strings::StrCat(
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -1062,7 +1083,7 @@ class MklConvOp : public OpKernel {
#endif
// Register 2D operations
-#define REGISTER_MKL_CPU(T) \
+#define REGISTER_MKL_CPU_2D(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
@@ -1079,16 +1100,16 @@ class MklConvOp : public OpKernel {
.Label(mkl_op_registry::kMklOpLabel), \
MklDummyOp<CPUDevice, T>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_2D);
// Register 3D operations
-#define REGISTER_MKL_CPU(T) \
+#define REGISTER_MKL_CPU_3D(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklConvOp<CPUDevice, T, false>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_3D);
} // namespace tensorflow
#endif // INTEL_MKL
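For reference, a minimal self-contained sketch of the conditional-caching pattern both convolution kernels above now follow; Primitive and PrimitiveFactory here are simplified stand-ins, not the actual MKL-DNN classes. When do_not_cache is true the caller owns the returned primitive and must delete it after Execute(); otherwise the factory keeps ownership in its pool.

// Hypothetical, simplified illustration of the do_not_cache flow (not the
// real MklConvFwdPrimitiveFactory).
#include <memory>
#include <string>
#include <unordered_map>

struct Primitive {
  explicit Primitive(const std::string& key) : key(key) {}
  std::string key;
};

class PrimitiveFactory {
 public:
  Primitive* Get(const std::string& key, bool do_not_cache) {
    if (do_not_cache) return new Primitive(key);  // caller must delete
    auto it = pool_.find(key);
    if (it == pool_.end()) {
      it = pool_.emplace(key, std::make_unique<Primitive>(key)).first;
    }
    return it->second.get();  // factory retains ownership
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Primitive>> pool_;
};

int main() {
  PrimitiveFactory factory;
  Primitive* cached = factory.Get("conv_3x3", /*do_not_cache=*/false);
  Primitive* uncached = factory.Get("conv_1x1_stride2", /*do_not_cache=*/true);
  // ... Execute(...) with either primitive ...
  delete uncached;  // only the uncached primitive is deleted by the caller
  (void)cached;     // the cached one stays alive inside the factory
  return 0;
}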
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
new file mode 100644
index 0000000000..a055351337
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "tensorflow/core/util/mkl_util.h"
+#endif
+
+// TODO(ezhulenev): Add numerical tests that will compare results of default
+// (aka Eigen) convolutions with MKL convolutions.
+
+// -------------------------------------------------------------------------- //
+// Performance Benchmarks. //
+// -------------------------------------------------------------------------- //
+
+// Compare performance of default TensorFlow convolution kernels (Eigen) with
+// MKL kernels on CPU.
+
+// Before running these benchmarks, set the OpenMP environment variables:
+// export KMP_BLOCKTIME=0
+// export OMP_NUM_THREADS=${num_threads}
+
+namespace tensorflow {
+
+struct Conv2DDimensions {
+ Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw)
+ : input_batches(n),
+ input_height(h),
+ input_width(w),
+ input_depth(c),
+ filter_count(fc),
+ filter_height(fh),
+ filter_width(fw) {}
+
+ int input_batches;
+ int input_height;
+ int input_width;
+ int input_depth;
+ int filter_count;
+ int filter_height;
+ int filter_width;
+};
+
+static Tensor GetRandomTensor(const TensorShape& shape) {
+ Tensor tensor(DT_FLOAT, TensorShape(shape));
+ tensor.flat<float>() = tensor.flat<float>().setRandom();
+ return tensor;
+}
+
+// Get a random Tensor for the Conv2D input.
+static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a random Tensor for the Conv2D filter.
+static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+// Get a random Tensor for the Conv2D output (assuming SAME padding).
+static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.filter_count});
+}
+
+// Get a Tensor encoding Conv2D input shape.
+static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a Tensor encoding Conv2D filter shape.
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Tensor NonMklTensor() {
+ MklDnnShape non_mkl_shape;
+ non_mkl_shape.SetMklTensor(false);
+
+ auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
+ Tensor tensor(DT_UINT8, {size});
+
+ non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
+ size * sizeof(uint8));
+ return tensor;
+}
+#endif
+
+static Graph* DefaultConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D")
+ .Input(input)
+ .Input(filter)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D")
+ .Input(input)
+ .Input(filter)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"),
+ "_MklConv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) {
+ Graph* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+ "_MklConv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+#endif
+
+// Macro argument names: ---------------------------------------------------- //
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+ BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
+
+// The flops computation in these benchmarks is the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                            \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (C); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                            \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                            \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+// ImageNet Convolutions ---------------------------------------------------- //
+
+BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+} // namespace tensorflow
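As a worked example of the flops accounting in BM_Conv2DT above (a standalone sketch, independent of the benchmark macros): for BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, ...), each iteration computes 32*28*28*128 output elements at a cost of 96*3*3 multiply-adds each.

#include <cstdint>
#include <cstdio>

int main() {
  // N=32, H=28, W=28, C=96, FC=128, FH=3, FW=3 (the "conv3a_00_3x3" case).
  const int64_t N = 32, H = 28, W = 28, C = 96, FC = 128, FH = 3, FW = 3;
  const int64_t num_computed_elements = N * H * W * FC;  // 3,211,264
  const int64_t flops_per_iter = num_computed_elements * C * FH * FW;
  std::printf("elements=%lld flops/iter=%lld\n",
              static_cast<long long>(num_computed_elements),
              static_cast<long long>(flops_per_iter));   // 2,774,532,096
  return 0;
}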
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index ec6d241e17..5398e6113f 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -34,11 +34,11 @@ using mkldnn::prop_kind;
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
- if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
- fwdParams.alg_kind != pooling_avg_include_padding &&
- fwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(fwdParams.alg_kind == pooling_max ||
+ fwdParams.alg_kind == pooling_avg ||
+ fwdParams.alg_kind == pooling_avg_include_padding ||
+ fwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = fwdParams.alg_kind;
// create memory desc
@@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
context_.fwd_stream->submit(context_.fwd_primitives);
@@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
@@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>;
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
- if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
- bwdParams.alg_kind != pooling_avg_include_padding &&
- bwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(bwdParams.alg_kind == pooling_max ||
+ bwdParams.alg_kind == pooling_avg ||
+ bwdParams.alg_kind == pooling_avg_include_padding ||
+ bwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = bwdParams.alg_kind;
// check whether it is 2d or 3d
@@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
@@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
context_.diff_dst_mem->set_data_handle(DummyData);
context_.diff_src_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
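The assert-to-DCHECK change above also fixes a latent problem: assert applied to a string literal tests a non-null pointer, which is always true, so the old "unsupported algorithm" check could never fire. A standalone illustration (plain assert only; DCHECK itself comes from TensorFlow's logging headers):

#include <cassert>

int main() {
  // A string literal is a non-null pointer, so this assertion always passes.
  assert("Pooling algorithm kind is not supported\n");
  // A real check must test the condition itself, e.g.
  //   DCHECK(alg_kind == pooling_max || ...) << "message";
  return 0;
}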
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index f4cfc48af5..84385356e1 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -40,7 +40,6 @@ using mkldnn::memory;
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#endif
-#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 8bde966be9..cfab529662 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -50,6 +50,7 @@ class MklSoftmaxOp : public OpKernel {
// src_tensor now points to the 0-th input of global data struct "context"
size_t src_idx = 0;
const Tensor& src_tensor = MklGetInput(context, src_idx);
+ const int input_dims = src_tensor.dims();
// Add: get MklShape
MklDnnShape src_mkl_shape;
@@ -62,7 +63,33 @@ class MklSoftmaxOp : public OpKernel {
: src_tensor.shape();
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
-
+ memory::format layout_type;
+    // The data format passed to the MKL-DNN softmax op depends on the rank
+    // of the input tensor: "x" is used for a 1-D tensor, "nc" for 2-D,
+    // "tnc" for 3-D, "nchw" for 4-D, and "ncdhw" for 5-D.
+    // Each of the symbols has the following meaning:
+    // n = batch, c = channels, t = sequence length, h = height,
+    // w = width, d = depth
+ switch (input_dims) {
+ case 1:
+ layout_type = memory::format::x;
+ break;
+ case 2:
+ layout_type = memory::format::nc;
+ break;
+ case 3:
+ layout_type = memory::format::tnc;
+ break;
+ case 4:
+ layout_type = memory::format::nchw;
+ break;
+ case 5:
+ layout_type = memory::format::ncdhw;
+ break;
+ default:
+        OP_REQUIRES_OK(context, errors::Aborted("Input dims must be >= 1 and <= 5"));
+ return;
+ }
// Create softmax memory for src, dst: both are defined in mkl_util.h,
// they are wrapper
MklDnnData<T> src(&cpu_engine);
@@ -75,7 +102,7 @@ class MklSoftmaxOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
- : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
+ : memory::desc(src_dims, MklDnnType<T>(), layout_type);
// src: setting memory descriptor and op memory descriptor
// Basically following two functions maps the TF "src_tensor" to mkl
@@ -84,10 +111,11 @@ class MklSoftmaxOp : public OpKernel {
// data format is "nc" for src and dst; since the src and dst buffer is
// always in 2D shape
src.SetUsrMem(src_md, &src_tensor);
- src.SetOpMemDesc(src_dims, memory::format::nc);
+ src.SetOpMemDesc(src_dims, layout_type);
// creating a memory descriptor
- int axis = 1; // axis to which softmax will be applied
+    // By default, apply softmax along the last dimension of the input.
+ int axis = input_dims - 1;
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);
auto softmax_fwd_pd =
@@ -107,7 +135,7 @@ class MklSoftmaxOp : public OpKernel {
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
- memory::format::nc);
+ layout_type);
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
} else { // then output is also TF shape
output_mkl_shape.SetMklTensor(false);
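A standalone sketch of the rank-to-format mapping and default softmax axis introduced above; plain strings stand in for the mkldnn::memory::format enum values used in the kernel.

#include <iostream>
#include <string>

static std::string LayoutForRank(int rank) {
  switch (rank) {
    case 1: return "x";
    case 2: return "nc";
    case 3: return "tnc";
    case 4: return "nchw";
    case 5: return "ncdhw";
    default: return "unsupported";
  }
}

int main() {
  for (int rank = 1; rank <= 5; ++rank) {
    // Softmax is applied along the last dimension: axis = rank - 1.
    std::cout << "rank " << rank << " -> format " << LayoutForRank(rank)
              << ", softmax axis " << rank - 1 << "\n";
  }
  return 0;
}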
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 5d9257e20b..81ce6d6e95 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
}
// Return intersection-over-union overlap between boxes i and j
-static inline float IOUGreaterThanThreshold(
- typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
- float iou_threshold) {
- const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
- const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
- const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
- const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
- const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
- const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
- const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
- const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
- const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
- const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
- if (area_i <= 0 || area_j <= 0) return 0.0;
- const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
- const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
- const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
- const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
- const float intersection_area =
- std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
- std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- const float iou = intersection_area / (area_i + area_j - intersection_area);
+template <typename T>
+static inline bool IOUGreaterThanThreshold(
+ typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
+ const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
+ const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
+ const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
+ const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
+ const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
+ const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
+ const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
+ const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
+ const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+ const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+  if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return false;
+ const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
+ const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
+ const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
+ const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
+ const T intersection_area =
+ std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
+ std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
+ const T iou = intersection_area / (area_i + area_j - intersection_area);
return iou > iou_threshold;
}
@@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold(
return overlaps(i, j) > overlap_threshold;
}
+template <typename T>
static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
const Tensor& boxes, float threshold) {
- typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
- return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
- std::placeholders::_2, threshold);
+ typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
+ return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
+ std::placeholders::_1, std::placeholders::_2,
+ static_cast<T>(threshold));
}
static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
@@ -121,6 +123,7 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
+template <typename T>
void DoNonMaxSuppressionOp(
OpKernelContext* context, const Tensor& scores, int num_boxes,
const Tensor& max_output_size, const float score_threshold,
@@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp(
bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = max_output_size.scalar<int>()();
- std::vector<float> scores_data(num_boxes);
- std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
+ std::vector<T> scores_data(num_boxes);
+ std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
// Data structure for selection candidate in NMS.
struct Candidate {
int box_index;
- float score;
+ T score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
@@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp(
std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
candidate_priority_queue(cmp);
for (int i = 0; i < scores_data.size(); ++i) {
- if (scores_data[i] > score_threshold) {
+ if (static_cast<float>(scores_data[i]) > score_threshold) {
candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
}
}
std::vector<int> selected;
- std::vector<float> selected_scores;
+ std::vector<T> selected_scores;
Candidate next_candidate;
while (selected.size() < output_size && !candidate_priority_queue.empty()) {
@@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp(
int num_valid_outputs = selected.size();
if (pad_to_max_output_size) {
selected.resize(output_size, 0);
- selected_scores.resize(output_size, 0);
+ selected_scores.resize(output_size, static_cast<T>(0));
}
if (ptr_num_valid_outputs) {
*ptr_num_valid_outputs = num_valid_outputs;
@@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
private:
float iou_threshold_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV2Op : public OpKernel {
public:
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
@@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel {
float score_threshold_val_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
@@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
@@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
int num_valid_outputs;
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn,
- pad_to_max_output_size_, &num_valid_outputs);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
// Allocate scalar output tensor for number of indices computed.
Tensor* num_outputs_t = nullptr;
@@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel {
auto suppress_check_fn =
CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
NonMaxSuppressionOp<CPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
- NonMaxSuppressionV2Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
- NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
- NonMaxSuppressionV4Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 876a1704c7..fc1c9003aa 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
@@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel {
for (auto d : lib->device_mgr()->ListDevices()) {
device_set.AddDevice(d);
}
- Placer placer(graph.get(), &device_set);
- OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
-
- std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
- OP_REQUIRES_OK_ASYNC(
- ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
- done);
// The FunctionLibraryRuntime's library cannot be mutated from within
// an OpKernel, so functions are instantiated in an overlay library.
@@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel {
new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
overlay_libs_.emplace(lib, overlay_lib);
+ GraphOptimizationPassOptions optimization_options;
+ // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or
+ // make it possible to specify the relevant options via attributes.
+ SessionOptions session_options;
+ session_options.env = ctx->env();
+ optimization_options.session_options = &session_options;
+ optimization_options.graph = &graph;
+ optimization_options.flib_def = overlay_lib;
+ optimization_options.device_set = &device_set;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
+ done);
+ Placer placer(graph.get(), &device_set);
+ OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
+ optimization_options),
+ done);
+
+ std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+ done);
+ optimization_options.graph = nullptr;
+ optimization_options.device_set = nullptr;
+ optimization_options.partition_graphs = &subgraphs;
+ OP_REQUIRES_OK_ASYNC(ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PARTITIONING,
+ optimization_options),
+ done);
+
auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
// TODO(akshayka): Fail gracefully if the set of devices corresponds
@@ -175,7 +210,7 @@ class PartitionedCallOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
DataType dtype = attr_value->type();
if (dtype == DT_RESOURCE) {
- ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0);
node->set_assigned_device_name(handle.device());
}
}
diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h
new file mode 100644
index 0000000000..f91244454e
--- /dev/null
+++ b/tensorflow/core/kernels/poisson-loss.h
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+
+#include <cmath>
+
+#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+class PoissonLossUpdater : public DualLossUpdater {
+ public:
+ // Update is found by a Newton algorithm (see readme.md).
+ double ComputeUpdatedDual(const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) const final {
+    // The Newton algorithm converges quadratically, so 10 steps are more
+    // than enough to achieve very good precision.
+ static const int newton_total_steps = 10;
+ // Initialize the Newton optimization at x such that
+ // exp(x) = label - current_dual
+ const double y_minus_a = label - current_dual;
+ double x = (y_minus_a > 0) ? log(y_minus_a) : 0;
+ for (int i = 0; i < newton_total_steps; ++i) {
+ x = NewtonStep(x, num_loss_partitions, label, wx, example_weight,
+ weighted_example_norm, current_dual);
+ }
+ return label - exp(x);
+ }
+
+  // Dual of the Poisson loss function.
+ // https://en.wikipedia.org/wiki/Convex_conjugate
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
+    // The dual of the Poisson loss function is
+    // (y-a)*(log(y-a)-1), where a is the dual variable.
+    // It is defined only for a < y.
+ const double y_minus_a = example_label - current_dual;
+ if (y_minus_a == 0.0) {
+ // (y-a)*(log(y-a)-1) approaches 0 as y-a approaches 0.
+ return 0.0;
+ }
+ if (y_minus_a < 0.0) {
+ return std::numeric_limits<double>::max();
+ }
+ return y_minus_a * (log(y_minus_a) - 1) * example_weight;
+ }
+
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
+ return (exp(wx) - wx * example_label) * example_weight;
+ }
+
+ double PrimalLossDerivative(const double wx, const double label,
+ const double example_weight) const final {
+ return (exp(wx) - label) * example_weight;
+ }
+
+ // TODO(chapelle): We need to introduce a maximum_prediction parameter,
+ // expose that parameter to the user and have this method return
+ // 1.0/maximum_prediction.
+ // Setting this at 1 for now, it only impacts the adaptive sampling.
+ double SmoothnessConstant() const final { return 1; }
+
+ Status ConvertLabel(float* const example_label) const final {
+ if (*example_label < 0.0) {
+ return errors::InvalidArgument(
+ "Only non-negative labels can be used with the Poisson log loss. "
+ "Found example with label: ", *example_label);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // One Newton step (see readme.md).
+ double NewtonStep(const double x, const int num_loss_partitions,
+ const double label, const double wx,
+ const double example_weight,
+ const double weighted_example_norm,
+ const double current_dual) const {
+ const double expx = exp(x);
+ const double numerator =
+ x - wx - num_loss_partitions * weighted_example_norm *
+ example_weight * (label - current_dual - expx);
+ const double denominator =
+ 1 + num_loss_partitions * weighted_example_norm * example_weight * expx;
+ return x - numerator / denominator;
+ }
+};
+
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
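A quick numerical sanity check of the dual expression above (a self-contained sketch, not TensorFlow code): for the per-example Poisson primal phi(z) = exp(z) - y*z, the convex conjugate evaluated at -a, i.e. sup_z(-a*z - phi(z)), equals (y-a)*(log(y-a)-1) whenever a < y.

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const double y = 3.0, a = 1.0;  // label and dual variable, with a < y
  double sup = -1e300;
  for (double z = -20.0; z <= 20.0; z += 1e-4) {
    sup = std::max(sup, -a * z - (std::exp(z) - y * z));
  }
  const double closed_form = (y - a) * (std::log(y - a) - 1.0);
  // The two values agree closely (both are about -0.6137).
  std::printf("grid sup = %.6f, closed form = %.6f\n", sup, closed_form);
  return 0;
}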
diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc
index d8d589f5aa..8a3e3dc0a9 100644
--- a/tensorflow/core/kernels/qr_op_complex128.cc
+++ b/tensorflow/core/kernels/qr_op_complex128.cc
@@ -24,7 +24,13 @@ REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
// cuSolver affecting older hardware. The cuSolver team is tracking the issue
// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
// this feature when a fix is available.
-// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<complex128>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<complex128>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc
index 63f2e03b3b..05537a0eaa 100644
--- a/tensorflow/core/kernels/qr_op_double.cc
+++ b/tensorflow/core/kernels/qr_op_double.cc
@@ -24,7 +24,13 @@ REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
// cuSolver affecting older hardware. The cuSolver team is tracking the issue
// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
// this feature when a fix is available.
-// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<double>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc
index 0b1a0aaa76..6aebd98186 100644
--- a/tensorflow/core/kernels/qr_op_float.cc
+++ b/tensorflow/core/kernels/qr_op_float.cc
@@ -24,7 +24,13 @@ REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
// cuSolver affecting older hardware. The cuSolver team is tracking the issue
// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
// this feature when a fix is available.
-// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<float>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index c4d404259b..97ddc852f7 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
+ const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 9cf953f4bf..8bfa44b2d0 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -50,6 +50,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
+
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index 5318d8c133..e4ca89eca3 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,7 +76,15 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.HostMemory("reduction_indices"),
ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-
+REGISTER_KERNEL_BUILDER(
+ Name("Sum")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("input")
+ .HostMemory("output")
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 5863a2c8e4..7edaaad8f7 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
RegexFullMatchOp);
+class StaticRegexFullMatchOp : public OpKernel {
+ public:
+ explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
+ }
+ }
+
+ private:
+ std::unique_ptr<RE2> re_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
+ StaticRegexFullMatchOp);
+
} // namespace tensorflow
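For context, a standalone RE2 usage sketch mirroring the new kernel: the pattern is compiled once (as the op does from its "pattern" attr) and FullMatch then runs per element. This assumes linking against re2 and is not TensorFlow code.

#include <iostream>
#include <string>
#include <vector>
#include "re2/re2.h"

int main() {
  RE2 re("[a-z]+[0-9]?");  // compiled once, like the op's constructor
  if (!re.ok()) {
    std::cerr << "Invalid pattern: " << re.error() << "\n";
    return 1;
  }
  const std::vector<std::string> inputs = {"abc1", "ABC", "xyz"};
  for (const auto& s : inputs) {
    std::cout << s << " -> " << (RE2::FullMatch(s, re) ? "true" : "false")
              << "\n";
  }
  return 0;
}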
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 194a711d98..26f107f940 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
std::unordered_set<string> retval;
for (const string& node_name_and_port : node_names_and_ports) {
const TensorId tid = ParseTensorName(node_name_and_port);
- retval.emplace(std::string(tid.first));
+ retval.emplace(tid.first);
}
return retval;
}
@@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) {
const NodeDef* FindNodeDefByName(const string& input,
const GraphDef& graph_def) {
const TensorId tid = ParseTensorName(input);
- const string name = std::string(tid.first);
+ const string name = string(tid.first);
for (const NodeDef& node_def : graph_def.node()) {
if (node_def.name() == name) {
return &node_def;
@@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
std::vector<DataType> data_types;
std::vector<TensorShape> shapes;
const TensorId tid = ParseTensorName(name_and_port);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType(
const TensorShapeMap& tensor_shape_map, const string& node_name) {
if (node_name.find(':') != string::npos) {
const TensorId tid = ParseTensorName(node_name);
- return GetTensorShapeType(tensor_shape_map, std::string(tid.first),
- tid.second);
+ return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
} else {
return GetTensorShapeType(tensor_shape_map, node_name, 0);
}
@@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
const TensorId tid = ParseTensorName(name);
CHECK_EQ(tensor_shape_map->count(name), 0);
tensor_shape_map->emplace(
- std::string(tid.first),
+ string(tid.first),
std::make_pair(tid.second,
std::make_pair(tensor.dtype(), tensor.shape())));
}
@@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::vector<NodeBuilder::NodeOut> node_out_list;
for (const string& input : inputs) {
const TensorId tid = ParseTensorName(input);
- Node* node = FindMutableNodeByName(std::string(tid.first), graph);
+ Node* node = FindMutableNodeByName(string(tid.first), graph);
CHECK_NOTNULL(node);
node_out_list.emplace_back(node, tid.second);
}
@@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (const string& subgraph_input : std::get<1>(cluster)) {
const TensorId tid = ParseTensorName(subgraph_input);
- const string subgraph_input_name = std::string(tid.first);
+ const string subgraph_input_name(tid.first);
const int subgraph_input_port = tid.second;
const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::deque<const Node*> queue;
for (const string& output : border_outputs) {
const TensorId tid = ParseTensorName(output);
- const string& output_node_name = std::string(tid.first);
+ const string output_node_name(tid.first);
for (const Node* node : graph.nodes()) {
if (output_node_name == node->name()) {
queue.push_back(node);
@@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (int j = 0; j < border_outputs.size(); ++j) {
const string& output = border_outputs.at(j);
const TensorId tid = ParseTensorName(output);
- const string output_name = std::string(tid.first);
+ const string output_name(tid.first);
Node* src_node = edge->src();
if (src_node != nullptr && src_node->name() == output_name &&
edge->src_output() == tid.second) {
@@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
// RemoteFusedGraphExecuteOpNode
for (const string& output : outputs) {
const TensorId output_tid = ParseTensorName(output);
- const string output_name = std::string(output_tid.first);
+ const string output_name(output_tid.first);
for (size_t i = 0; i < border_outputs.size(); ++i) {
const TensorId subgraph_output_tid =
ParseTensorName(border_outputs.at(i));
- const string& subgraph_output_name =
- std::string(subgraph_output_tid.first);
+ const string subgraph_output_name(subgraph_output_tid.first);
if (output_name == subgraph_output_name) {
LOG(INFO) << "As graph output and subgraph output are same, "
<< "the graph output node is replaced by identity node";
@@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
GraphDef* graph_def) {
const TensorId tid = ParseTensorName(input);
CHECK_EQ(0, tid.second);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
for (NodeDef& node : *graph_def->mutable_node()) {
if (node.name() != node_name) {
continue;
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index ebcfb673d1..26705a8d34 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -79,7 +79,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, 0);
+ const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
errors::FailedPrecondition(
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 15a707a9c6..cded417986 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -64,7 +64,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
}
template <>
@@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<Tlen>();
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
+ if (!context->status().ok()) return;
const int input_dims = input.dims();
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index e335e38bdc..82546d581a 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context,
// If we cannot find a cached reader we will allocate our own.
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
- const checkpoint::TensorSliceReader* reader =
- context->slice_reader_cache()->GetReader(file_pattern, open_func,
- preferred_shard);
+ const checkpoint::TensorSliceReader* reader = nullptr;
+
+ if (context->slice_reader_cache()) {
+ reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
+ preferred_shard);
+ }
if (!reader) {
allocated_reader.reset(new checkpoint::TensorSliceReader(
file_pattern, open_func, preferred_shard));
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index ab4de6c815..180eb3ca34 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel {
context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
if (delete_old_dirs_) {
- const string& merged_dir = std::string(io::Dirname(merged_prefix));
+ const string merged_dir(io::Dirname(merged_prefix));
for (const string& input_prefix : input_prefixes) {
- const string& dirname = std::string(io::Dirname(input_prefix));
+ const string dirname(io::Dirname(input_prefix));
if (dirname == merged_dir) continue;
Status status = env->DeleteDir(dirname);
// For sharded save, only the first delete will go through and all
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 1c071d3d41..a8e9b3261c 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -251,7 +251,7 @@ Status Examples::SampleAdaptiveProbabilities(
num_weight_vectors);
const double kappa = example_state_data(example_id, 0) +
loss_updater->PrimalLossDerivative(
- example_statistics.wx[0], label, example_weight);
+ example_statistics.wx[0], label, 1.0);
probabilities_[example_id] = example_weight *
sqrt(examples_[example_id].squared_norm_ +
regularization.symmetric_l2() *
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index 05c835ebc4..3bd4168dc7 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/sdca_internal.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
@@ -75,6 +76,8 @@ struct ComputeOptions {
loss_updater.reset(new HingeLossUpdater);
} else if (loss_type == "smooth_hinge_loss") {
loss_updater.reset(new SmoothHingeLossUpdater);
+ } else if (loss_type == "poisson_loss") {
+ loss_updater.reset(new PoissonLossUpdater);
} else {
OP_REQUIRES(
context, false,
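
The SDCA kernel extends its if/else chain with a poisson_loss branch. The same dispatch can be sketched as a name-to-factory map in standalone C++; the loss classes below are hypothetical stand-ins, not the TensorFlow loss updaters:

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>

    // Hypothetical loss interface.
    struct LossUpdater {
      virtual ~LossUpdater() = default;
      virtual std::string Name() const = 0;
    };
    struct LogisticLoss : LossUpdater {
      std::string Name() const override { return "logistic_loss"; }
    };
    struct PoissonLoss : LossUpdater {
      std::string Name() const override { return "poisson_loss"; }
    };

    std::unique_ptr<LossUpdater> MakeLoss(const std::string& loss_type) {
      // One factory per supported loss type; adding a loss is one new entry.
      static const std::unordered_map<
          std::string, std::function<std::unique_ptr<LossUpdater>()>>
          factories = {
              {"logistic_loss",
               []() -> std::unique_ptr<LossUpdater> {
                 return std::make_unique<LogisticLoss>();
               }},
              {"poisson_loss",
               []() -> std::unique_ptr<LossUpdater> {
                 return std::make_unique<PoissonLoss>();
               }},
          };
      auto it = factories.find(loss_type);
      if (it == factories.end()) return nullptr;
      return it->second();
    }

    int main() {
      auto loss = MakeLoss("poisson_loss");
      std::cout << (loss ? loss->Name() : "unsupported loss type") << "\n";
      return 0;
    }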
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index 9cd590ae61..30cb1e0a7f 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
- GetShapeFromKnownVecSize);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize);
static void ExpectHasError(const Status& s, StringPiece substr) {
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
@@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) {
Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs);
EXPECT_FALSE(s.ok());
ExpectHasError(
- s,
- "No unary variant shape function found for Variant type_name: "
- "NO KNOWN SHAPE");
+ s, strings::StrCat(
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name())));
}
{
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 11149c4d16..a4453bd7ab 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -50,10 +50,10 @@ class SparseConditionalAccumulator
public:
SparseConditionalAccumulator(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
+ const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase<
std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
- dtype, shape, name) {
+ dtype, shape, name, reduction_type) {
accum_idx_vec_ = nullptr;
count_element_ = nullptr;
accum_val_ = nullptr;
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
index 80bc1f1934..1e542a26a7 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
@@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
SparseConditionalAccumulator<Device, T>* accumulator =
- new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
- cinfo_.name());
+ new SparseConditionalAccumulator<Device, T>(
+ dtype_, shape_, cinfo_.name(), reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 7cc3c532c9..11db72bfa3 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel {
void ComputeEasyCases(OpKernelContext* context, bool* done) {
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
- const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const Tensor& split_dim_tensor = context->input(0);
+ OP_REQUIRES(
+ context, split_dim_tensor.shape().dims() == 0,
+ errors::InvalidArgument("split_dim must be a scalar but has rank ",
+ split_dim_tensor.shape().dims()));
+ const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = num_outputs();
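
SplitOp now rejects a non-scalar split_dim before dereferencing element 0 of its flat view. A standalone sketch of the same guard over a hypothetical tensor struct:

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical dense tensor: shape plus flat int32 storage.
    struct IntTensor {
      std::vector<int> shape;   // Empty shape means scalar.
      std::vector<int> values;  // Flattened values.
      int dims() const { return static_cast<int>(shape.size()); }
    };

    // Returns true and fills *split_dim only if the tensor really is a scalar.
    bool ReadScalarSplitDim(const IntTensor& t, int* split_dim,
                            std::string* error) {
      if (t.dims() != 0) {
        *error = "split_dim must be a scalar but has rank " +
                 std::to_string(t.dims());
        return false;
      }
      *split_dim = t.values[0];
      return true;
    }

    int main() {
      IntTensor bad{{3}, {0, 1, 2}};  // Rank-1 tensor: should be rejected.
      int split_dim = 0;
      std::string error;
      if (!ReadScalarSplitDim(bad, &split_dim, &error)) {
        std::cout << "error: " << error << "\n";
      }

      IntTensor good{{}, {1}};  // Scalar: accepted.
      if (ReadScalarSplitDim(good, &split_dim, &error)) {
        std::cout << "split_dim = " << split_dim << "\n";
      }
      return 0;
    }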
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 65296f61fd..add4afafc9 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -131,10 +131,8 @@ class Stack : public ResourceBase {
};
Status GetStack(OpKernelContext* ctx, Stack** stack) {
- string key;
if (ctx->input_dtype(0) == DT_RESOURCE) {
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
- key = resource.name();
+ return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
} else {
Tensor Tstack_handle = ctx->mutable_input(0, false);
if (Tstack_handle.NumElements() != 2) {
@@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
}
const string& container = Tstack_handle.flat<string>()(0);
const string& stack_name = Tstack_handle.flat<string>()(1);
- key = strings::StrCat(container, stack_name);
- }
- ResourceMgr* rm = ctx->resource_manager();
- if (rm == nullptr) {
- return errors::Internal("No resource manager.");
- }
- auto* step_container = ctx->step_container();
- if (step_container == nullptr) {
- return errors::Internal("No step container.");
+ string key = strings::StrCat(container, stack_name);
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) {
+ return errors::Internal("No resource manager.");
+ }
+ auto* step_container = ctx->step_container();
+ if (step_container == nullptr) {
+ return errors::Internal("No step container.");
+ }
+ TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
+ return Status::OK();
}
- TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
- return Status::OK();
}
std::atomic<int64> Stack::stack_counter{0};
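
GetStack now resolves resource-handle inputs directly and keeps the container/name key only for the legacy ref input. A simplified registry sketch of the two lookup paths (hypothetical Stack/Registry types, not the TensorFlow resource manager):

    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    struct Stack {
      std::string label;
    };

    // Hypothetical registry keyed by a single string, mimicking a resource
    // manager scoped to one step container.
    class Registry {
     public:
      void Create(const std::string& key, const std::string& label) {
        stacks_[key] = std::make_unique<Stack>(Stack{label});
      }
      Stack* Lookup(const std::string& key) {
        auto it = stacks_.find(key);
        return it == stacks_.end() ? nullptr : it->second.get();
      }

     private:
      std::map<std::string, std::unique_ptr<Stack>> stacks_;
    };

    // Resource-handle path: the handle already carries the full key.
    Stack* GetStackFromHandle(Registry* reg, const std::string& handle_name) {
      return reg->Lookup(handle_name);
    }

    // Legacy ref path: the key is assembled from container and stack name.
    Stack* GetStackFromRef(Registry* reg, const std::string& container,
                           const std::string& stack_name) {
      return reg->Lookup(container + stack_name);
    }

    int main() {
      Registry reg;
      reg.Create("train_step/stack_7", "my stack");
      Stack* a = GetStackFromHandle(&reg, "train_step/stack_7");
      Stack* b = GetStackFromRef(&reg, "train_step/", "stack_7");
      std::cout << (a == b ? "same stack: " : "different: ") << a->label << "\n";
      return 0;
    }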
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
index 2aeafa28c4..544dca96ba 100644
--- a/tensorflow/core/kernels/string_strip_op.cc
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -43,7 +43,7 @@ class StringStripOp : public OpKernel {
for (int64 i = 0; i < input.size(); ++i) {
StringPiece entry(input(i));
str_util::RemoveWhitespaceContext(&entry);
- output(i) = std::string(entry);
+ output(i) = string(entry);
}
}
};
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e45918a0..07f1d6e767 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdlib>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -25,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -64,26 +68,28 @@ class SubstrOp : public OpKernel {
const T len =
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
} else {
// Perform Op element-wise with tensor pos/len
auto pos_flat = pos_tensor.flat<T>();
auto len_flat = len_tensor.flat<T>();
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
}
} else {
@@ -142,14 +148,16 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
+ StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
+ context,
+ FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
break;
}
@@ -192,16 +200,18 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
+ StringPiece in(input_bcast(i, j));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument(
- "pos ", pos, " out of range for ", "string b'",
- in, "' at index (", i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
+ OP_REQUIRES(
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (", i,
+ ", ", j, ")"));
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i, j).assign(sub_in.data(), sub_in.size());
}
}
break;
@@ -213,6 +223,16 @@ class SubstrOp : public OpKernel {
}
}
}
+
+ private:
+ // This adjusts the requested position. Note it does not perform any bounds
+ // checks.
+ T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ if (pos_requested < 0) {
+ return s.size() + pos_requested;
+ }
+ return pos_requested;
+ }
};
#define REGISTER_SUBSTR(type) \
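
The Substr kernel now accepts negative positions: it bounds-checks std::abs(pos) against the string length and maps negative positions to offsets from the end before slicing. A standalone sketch of that adjustment over std::string_view (illustrative only, not the kernel code):

    #include <cstdlib>
    #include <iostream>
    #include <string>
    #include <string_view>

    // Maps a possibly negative position to an absolute index; like the kernel's
    // helper, it assumes the caller has already bounds-checked std::abs(pos).
    long AdjustedPos(long pos, std::string_view s) {
      return pos < 0 ? static_cast<long>(s.size()) + pos : pos;
    }

    // Returns true and writes the substring when pos is within [-len(in), len(in)].
    bool Substr(std::string_view in, long pos, long len, std::string* out) {
      if (static_cast<unsigned long>(std::abs(pos)) > in.size()) {
        return false;  // pos out of range for this string.
      }
      std::string_view sub = in.substr(AdjustedPos(pos, in), len);
      out->assign(sub.data(), sub.size());
      return true;
    }

    int main() {
      std::string out;
      Substr("TensorFlow", 0, 6, &out);   // "Tensor"
      std::cout << out << "\n";
      Substr("TensorFlow", -4, 4, &out);  // "Flow": counted from the end.
      std::cout << out << "\n";
      std::cout << std::boolalpha << Substr("abc", -5, 2, &out) << "\n";  // false
      return 0;
    }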
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
new file mode 100644
index 0000000000..2e07050260
--- /dev/null
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor position(DT_INT32, TensorShape({}));
+ position.flat<int32>().setConstant(pos);
+ Tensor length(DT_INT32, TensorShape({}));
+ length.flat<int32>().setConstant(len);
+
+ TF_CHECK_OK(NodeBuilder("substr_op", "Substr")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, position))
+ .Input(test::graph::Constant(g, length))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_Substr(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
+ 256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 632b65e9b6..fe93b91eb8 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -290,14 +290,14 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
}
} else {
container = "_tensor_arrays";
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ const auto& resource = ctx->input(0).flat<ResourceHandle>()(0);
if (StringPiece(resource.name()).substr(0, container.size()) !=
container) {
return errors::InvalidArgument("Wrong input container. ",
resource.name());
}
tensor_array_name =
- std::string(StringPiece(resource.name()).substr(container.size()));
+ string(StringPiece(resource.name()).substr(container.size()));
}
auto output_handle = tensor_array_output_handle->flat<string>();
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index 9dedb618f9..ca341e511e 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
public:
TypedConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
- : ConditionalAccumulatorBase(dtype, shape, name) {}
+ const string& name,
+ const string& reduction_type)
+ : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
/**
* Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
index 62e814ff77..8d839ba85a 100644
--- a/tensorflow/core/kernels/unravel_index_op.cc
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -97,10 +97,12 @@ class UnravelIndexOp : public OpKernel {
auto output = output_tensor->matrix<Tidx>();
- Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
- Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
- Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
- Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+ Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}};
+ Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()});
+ Eigen::array<Eigen::Index, 2> indices_reshape{
+ {1, indices_tensor.NumElements()}};
+ Eigen::array<Eigen::Index, 2> indices_bcast(
+ {dims_tensor.NumElements(), 1});
output = indices_tensor.vec<Tidx>()
.reshape(indices_reshape)
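
The unravel_index change only switches the Eigen arrays to Eigen::Index, but the computation it feeds, turning flat indices into per-dimension coordinates, is easy to sketch directly in standalone C++:

    #include <iostream>
    #include <vector>

    // Converts a flat index into coordinates for a row-major shape, e.g.
    // flat index 7 in shape {3, 4} becomes {1, 3}.
    std::vector<long long> UnravelIndex(long long flat,
                                        const std::vector<long long>& dims) {
      std::vector<long long> coords(dims.size(), 0);
      // Walk dimensions from innermost to outermost, peeling off remainders.
      for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
        coords[d] = flat % dims[d];
        flat /= dims[d];
      }
      return coords;
    }

    int main() {
      std::vector<long long> tests = {0, 7, 11};
      for (long long idx : tests) {
        std::vector<long long> c = UnravelIndex(idx, {3, 4});
        std::cout << idx << " -> (" << c[0] << ", " << c[1] << ")\n";
      }
      return 0;
    }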
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
index ed2bf3e8e2..1bf46b5e46 100644
--- a/tensorflow/core/kernels/whole_file_read_ops.cc
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel {
"Contents tensor must be scalar, but had shape: ",
contents_input->shape().DebugString()));
const string& filename = filename_input->scalar<string>()();
- const string dir = std::string(io::Dirname(filename));
+ const string dir(io::Dirname(filename));
if (!context->env()->FileExists(dir).ok()) {
OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
}
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 49a8a4dbd4..d5cbe6c616 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -131,11 +131,23 @@ inline string FormatNodeNameForError(const string& name) {
// LINT.ThenChange(//tensorflow/python/client/session.py)
template <typename T>
string FormatNodeNamesForError(const T& names) {
- ::tensorflow::str_util::Formatter<string> f(
- [](string* output, const string& s) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
});
- return ::tensorflow::str_util::Join(names, ", ", f);
+}
+// LINT.IfChange
+inline string FormatColocationNodeForError(const string& name) {
+ return strings::StrCat("{{colocation_node ", name, "}}");
+}
+// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
+template <typename T>
+string FormatColocationNodeForError(const T& names) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
+ ::tensorflow::strings::StrAppend(output,
+ FormatColocationNodeForError(s));
+ });
}
// The CanonicalCode() for non-errors.
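
FormatColocationNodeForError mirrors the node-name helper: wrap each name, then join the results with ", ". A standalone sketch of the join-with-formatter idiom in plain C++ (a hypothetical Join, not str_util::Join):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Joins elements with a separator, letting a formatter decide how each
    // element is appended; roughly the shape of Join(names, sep, formatter).
    std::string Join(
        const std::vector<std::string>& names, const std::string& sep,
        const std::function<void(std::string*, const std::string&)>& formatter) {
      std::string out;
      for (size_t i = 0; i < names.size(); ++i) {
        if (i > 0) out += sep;
        formatter(&out, names[i]);
      }
      return out;
    }

    std::string FormatColocationNode(const std::string& name) {
      return "{{colocation_node " + name + "}}";
    }

    int main() {
      std::vector<std::string> names = {"conv1/weights", "conv1/bias"};
      std::string msg =
          Join(names, ", ", [](std::string* out, const std::string& s) {
            *out += FormatColocationNode(s);
          });
      std::cout << msg << "\n";
      return 0;
    }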
diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h
index 49f74ff47f..eb0ff555a5 100644
--- a/tensorflow/core/lib/core/status.h
+++ b/tensorflow/core/lib/core/status.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index e7b17c9b36..6edff139ae 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -26,13 +26,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
#define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
-#include <assert.h>
-#include <stddef.h>
-#include <string.h>
-#include <iosfwd>
-#include <string>
#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index c18dc9ad1a..2d622dc229 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -13,674 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
-// for sequences of length <= N are provided inline without requiring
-// any heap allocation. Typically N is very small (e.g., 4) so that
-// sequences that are expected to be short do not require allocations.
-//
-// Only some of the std::vector<> operations are currently implemented.
-// Other operations may be added as needed to facilitate migrating
-// code that uses std::vector<> to InlinedVector<>.
-//
-// NOTE: If you want an inlined version to replace use of a
-// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
-// in util/bitmap/inlined_bitvector.h
-//
-// TODO(billydonahue): change size_t to size_type where appropriate.
-
#ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
#define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
-#include <stddef.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
-#include <algorithm>
-#include <cstddef>
-#include <iterator>
-#include <memory>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/manual_constructor.h"
-#include "tensorflow/core/platform/byte_order.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mem.h"
+#include "absl/container/inlined_vector.h"
+// TODO(kramerb): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include <initializer_list> // NOLINT(build/include_order)
-
namespace tensorflow {
namespace gtl {
-template <typename T, int N>
-class InlinedVector {
- public:
- typedef T value_type;
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef size_t size_type;
- typedef std::ptrdiff_t difference_type;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
-
- // Create an empty vector
- InlinedVector();
-
- // Create a vector with n copies of value_type().
- explicit InlinedVector(size_t n);
-
- // Create a vector with n copies of elem
- InlinedVector(size_t n, const value_type& elem);
-
- // Create and initialize with the elements [range_start .. range_end).
- // The unused enable_if argument restricts this constructor so that it is
- // elided when value_type is an integral type. This prevents ambiguous
- // interpretation between a call to this constructor with two integral
- // arguments and a call to the preceding (n, elem) constructor.
- template <typename InputIterator>
- InlinedVector(
- InputIterator range_start, InputIterator range_end,
- typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
- NULL) {
- InitRep();
- AppendRange(range_start, range_end);
- }
-
- InlinedVector(std::initializer_list<value_type> init) {
- InitRep();
- AppendRange(init.begin(), init.end());
- }
-
- InlinedVector(const InlinedVector& v);
-
- ~InlinedVector() { clear(); }
-
- InlinedVector& operator=(const InlinedVector& v) {
- // Optimized to avoid reallocation.
- // Prefer reassignment to copy construction for elements.
- const size_t s = size();
- const size_t vs = v.size();
- if (s < vs) { // grow
- reserve(vs);
- if (s) std::copy(v.begin(), v.begin() + s, begin());
- std::copy(v.begin() + s, v.end(), std::back_inserter(*this));
- } else { // maybe shrink
- erase(begin() + vs, end());
- std::copy(v.begin(), v.end(), begin());
- }
- return *this;
- }
-
- size_t size() const { return size_internal(); }
-
- bool empty() const { return (size() == 0); }
-
- // Return number of elements that can be stored in vector
- // without requiring a reallocation of underlying memory
- size_t capacity() const {
- if (is_inline()) {
- return kFit;
- } else {
- return static_cast<size_t>(1) << u_.data[kSize - 2];
- }
- }
-
- // Return a pointer to the underlying array.
- // Only result[0,size()-1] are defined.
- pointer data() {
- if (is_inline()) {
- return reinterpret_cast<T*>(u_.data);
- } else {
- return outofline_pointer();
- }
- }
- const_pointer data() const {
- return const_cast<InlinedVector<T, N>*>(this)->data();
- }
-
- // Remove all elements
- void clear() {
- DiscardStorage();
- u_.data[kSize - 1] = 0;
- }
-
- // Return the ith element
- // REQUIRES: 0 <= i < size()
- const value_type& at(size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
- const value_type& operator[](size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- // Return a non-const reference to the ith element
- // REQUIRES: 0 <= i < size()
- value_type& at(size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
- value_type& operator[](size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- value_type& back() {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- const value_type& back() const {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- value_type& front() {
- DCHECK(!empty());
- return at(0);
- }
-
- const value_type& front() const {
- DCHECK(!empty());
- return at(0);
- }
-
- // Append a T constructed with args to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- template <typename... Args>
- void emplace_back(Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
- if (s < capacity()) {
- new (data() + s) T(std::forward<Args>(args)...);
- set_size_internal(s + 1);
- } else {
- EmplaceBackSlow(std::forward<Args>(args)...);
- }
- }
-
- // Append t to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- void push_back(const value_type& t) { emplace_back(t); }
- void push_back(value_type&& t) { emplace_back(std::move(t)); }
-
- inline void pop_back() {
- DCHECK(!empty());
- const size_t s = size();
- Destroy(data() + s - 1, 1);
- set_size_internal(s - 1);
- }
-
- // Resizes the vector to contain "n" elements.
- // If "n" is smaller than the initial size, extra elements are destroyed.
- // If "n" is larger than the initial size, enough copies of "elem"
- // are appended to increase the size to "n". If "elem" is omitted,
- // new elements are value-initialized.
- void resize(size_t n) { Resize<ValueInit>(n, nullptr); }
- void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); }
-
- iterator begin() { return data(); }
- const_iterator begin() const { return data(); }
-
- iterator end() { return data() + size(); }
- const_iterator end() const { return data() + size(); }
-
- iterator insert(iterator pos, const value_type& v);
-
- iterator erase(iterator pos) {
- DCHECK_LT(pos, end());
- DCHECK_GE(pos, begin());
- std::copy(pos + 1, end(), pos);
- pop_back();
- return pos;
- }
-
- iterator erase(iterator first, iterator last);
-
- // Enlarges the underlying representation so it can hold at least
- // "n" elements without reallocation.
- // Does not change size() or the actual contents of the vector.
- void reserve(size_t n) {
- if (n > capacity()) {
- // Make room for new elements
- Grow<Move>(n);
- }
- }
-
- // Swap the contents of *this with other.
- // REQUIRES: value_type is swappable and copyable.
- void swap(InlinedVector& other);
-
- private:
- // Representation can either be inlined or out-of-line.
- // In either case, at least sizeof(void*) + 8 bytes are available.
- //
- // Inlined:
- // Last byte holds the length.
- // First (length*sizeof(T)) bytes stores the elements.
- // Outlined:
- // Last byte holds kSentinel.
- // Second-last byte holds lg(capacity)
- // Preceding 6 bytes hold size.
- // First sizeof(T*) bytes hold pointer.
-
- // Compute rep size.
- static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag
- static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align
-
- // See how many fit T we can fit inside kSize, but no more than 254
- // since 255 is used as sentinel tag for out-of-line allocation.
- static const unsigned int kSentinel = 255;
- static const size_t kFit1 = (kSize - 1) / sizeof(T);
- static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1;
-
- union {
- unsigned char data[kSize];
- // Force data to be aligned enough for a pointer.
- T* unused_aligner;
- } u_;
-
- inline void InitRep() { u_.data[kSize - 1] = 0; }
- inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; }
-
- inline T* outofline_pointer() const {
- T* ptr;
- memcpy(&ptr, &u_.data[0], sizeof(ptr));
- return ptr;
- }
-
- inline void set_outofline_pointer(T* p) {
- memcpy(&u_.data[0], &p, sizeof(p));
- }
-
- inline uint64_t outofline_word() const {
- uint64_t word;
- memcpy(&word, &u_.data[kSize - 8], sizeof(word));
- return word;
- }
-
- inline void set_outofline_word(uint64_t w) {
- memcpy(&u_.data[kSize - 8], &w, sizeof(w));
- }
-
- inline size_t size_internal() const {
- uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]);
- if (s != kSentinel) {
- return static_cast<size_t>(s);
- } else {
- const uint64_t word = outofline_word();
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- return static_cast<size_t>(word & 0xffffffffffffull);
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- return static_cast<size_t>(word >> 16);
- }
- }
- }
-
- void set_size_internal(size_t n) {
- if (is_inline()) {
- DCHECK_LT(n, kSentinel);
- u_.data[kSize - 1] = static_cast<unsigned char>(n);
- } else {
- uint64_t word;
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- word = (static_cast<uint64_t>(n) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) |
- (static_cast<uint64_t>(kSentinel) << 56));
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- word = ((static_cast<uint64_t>(n) << 16) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) |
- (static_cast<uint64_t>(kSentinel)));
- }
- set_outofline_word(word);
- DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n;
- }
- }
-
- void DiscardStorage() {
- T* base = data();
- size_t n = size();
- Destroy(base, n);
- if (!is_inline()) {
- port::Free(base);
- }
- }
-
- template <typename... Args>
- void EmplaceBackSlow(Args&&... args) {
- const size_t s = size();
- DCHECK_EQ(s, capacity());
- Grow<Move, Construct>(s + 1, std::forward<Args>(args)...);
- set_size_internal(s + 1);
- }
-
- // Movers for Grow
- // Does nothing.
- static void Nop(T* src, size_t n, T* dst) {}
-
- // Moves srcs[0,n-1] contents to dst[0,n-1].
- static void Move(T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(std::move(*(src + i)));
- }
- }
-
- // Initializers for Resize.
- // Initializes dst[0,n-1] with empty constructor.
- static void ValueInit(const T*, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T();
- }
- }
-
- // Initializes dst[0,n-1] with copies of *src.
- static void Fill(const T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(*src);
- }
- }
-
- void Destroy(T* src, int n) {
- if (!std::is_trivially_destructible<T>::value) {
- for (int i = 0; i < n; i++) {
- (src + i)->~T();
- }
- }
- }
-
- // Initialization methods for Grow.
- // 1) Leave uninitialized memory.
- struct Uninitialized {
- void operator()(T*) const {}
- };
- // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
- struct Construct {
- template <class... Args>
- void operator()(T* dst, Args&&... args) const {
- new (dst) T(std::forward<Args>(args)...);
- }
- };
-
- // Grow so that capacity >= n. Uses Mover to move existing elements
- // to new buffer, and possibly initialize the new element according
- // to InitType.
- // We pass the InitType and Mover as template arguments so that
- // this code compiles even if T does not support copying or default
- // construction.
- template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized,
- class... Args>
- void Grow(size_t n, Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
-
- // Compute new capacity by repeatedly doubling current capacity
- size_t target = 1;
- size_t target_lg = 0;
- while (target < kFit || target < n) {
- // TODO(psrc): Check and avoid overflow?
- target_lg++;
- target <<= 1;
- }
-
- T* src = data();
- T* dst = static_cast<T*>(port::Malloc(target * sizeof(T)));
-
- // Need to copy elem before discarding src since it might alias src.
- InitType{}(dst + s, std::forward<Args>(args)...);
- Mover(src, s, dst);
- DiscardStorage();
-
- u_.data[kSize - 1] = kSentinel;
- u_.data[kSize - 2] = static_cast<unsigned char>(target_lg);
- set_size_internal(s);
- DCHECK_EQ(capacity(), target);
- set_outofline_pointer(dst);
- }
-
- // Resize to size n. Any new elements are initialized by passing
- // elem and the destination to Initializer. We pass the Initializer
- // as a template argument so that this code compiles even if T does
- // not support copying.
- template <void(Initializer)(const T*, size_t, T*)>
- void Resize(size_t n, const T* elem) {
- size_t s = size();
- if (n <= s) {
- Destroy(data() + n, s - n);
- set_size_internal(n);
- return;
- }
- reserve(n);
- DCHECK_GE(capacity(), n);
- set_size_internal(n);
- Initializer(elem, n - s, data() + s);
- }
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::input_iterator_tag);
-
- // Faster path for forward iterators.
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last);
-};
-
-// Provide linkage for constants.
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSizeUnaligned;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSize;
-template <typename T, int N>
-const unsigned int InlinedVector<T, N>::kSentinel;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit1;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit;
-
-template <typename T, int N>
-inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) {
- a.swap(b);
-}
-
-template <typename T, int N>
-inline bool operator==(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
-}
-
-template <typename T, int N>
-inline bool operator!=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a == b);
-}
-
-template <typename T, int N>
-inline bool operator<(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
-}
-
-template <typename T, int N>
-inline bool operator>(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return b < a;
-}
-
-template <typename T, int N>
-inline bool operator<=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(b < a);
-}
-
-template <typename T, int N>
-inline bool operator>=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a < b);
-}
-
-// ========================================
-// Implementation
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector() {
- InitRep();
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Must use Nop in case T is not copyable
- }
- set_size_internal(n);
- ValueInit(nullptr, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Can use Nop since we know we have nothing to copy
- }
- set_size_internal(n);
- Fill(&elem, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) {
- InitRep();
- *this = v;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert(
- iterator pos, const value_type& v) {
- DCHECK_GE(pos, begin());
- DCHECK_LE(pos, end());
- if (pos == end()) {
- push_back(v);
- return end() - 1;
- }
- size_t s = size();
- size_t idx = std::distance(begin(), pos);
- if (s == capacity()) {
- Grow<Move>(s + 1);
- }
- CHECK_LT(s, capacity());
- pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator.
- Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1]
- std::copy_backward(pos, data() + s - 1, data() + s);
- *pos = v;
-
- set_size_internal(s + 1);
- return pos;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase(
- iterator first, iterator last) {
- DCHECK_LE(begin(), first);
- DCHECK_LE(first, last);
- DCHECK_LE(last, end());
-
- size_t s = size();
- ptrdiff_t erase_gap = std::distance(first, last);
- std::copy(last, data() + s, first);
- Destroy(data() + s - erase_gap, erase_gap);
- set_size_internal(s - erase_gap);
- return first;
-}
-
-template <typename T, int N>
-void InlinedVector<T, N>::swap(InlinedVector& other) {
- using std::swap; // Augment ADL with std::swap.
- if (&other == this) {
- return;
- }
-
- InlinedVector* a = this;
- InlinedVector* b = &other;
-
- const bool a_inline = a->is_inline();
- const bool b_inline = b->is_inline();
-
- if (!a_inline && !b_inline) {
- // Just swap the top-level representations.
- T* aptr = a->outofline_pointer();
- T* bptr = b->outofline_pointer();
- a->set_outofline_pointer(bptr);
- b->set_outofline_pointer(aptr);
-
- uint64_t aword = a->outofline_word();
- uint64_t bword = b->outofline_word();
- a->set_outofline_word(bword);
- b->set_outofline_word(aword);
- return;
- }
-
- // Make a the larger of the two to reduce number of cases.
- size_t a_size = a->size();
- size_t b_size = b->size();
- if (a->size() < b->size()) {
- swap(a, b);
- swap(a_size, b_size);
- }
- DCHECK_GE(a_size, b_size);
-
- if (b->capacity() < a_size) {
- b->Grow<Move>(a_size);
- }
-
- // One is inline and one is not.
- // 'a' is larger. Swap the elements up to the smaller array size.
- std::swap_ranges(a->data(), a->data() + b_size, b->data());
- std::uninitialized_copy(a->data() + b_size, a->data() + a_size,
- b->data() + b_size);
- Destroy(a->data() + b_size, a_size - b_size);
- a->set_size_internal(b_size);
- b->set_size_internal(a_size);
- DCHECK_EQ(b->size(), a_size);
- DCHECK_EQ(a->size(), b_size);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::input_iterator_tag) {
- std::copy(first, last, std::back_inserter(*this));
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::forward_iterator_tag) {
- typedef typename std::iterator_traits<Iter>::difference_type Length;
- Length length = std::distance(first, last);
- size_t s = size();
- reserve(s + length);
- std::uninitialized_copy_n(first, length, data() + s);
- set_size_internal(s + length);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) {
- typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
- AppendRange(first, last, IterTag());
-}
+using absl::InlinedVector;
} // namespace gtl
} // namespace tensorflow
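
With the hand-rolled class removed, gtl::InlinedVector becomes an alias for absl::InlinedVector, so existing call sites keep compiling. A minimal usage sketch against Abseil directly (assumes the absl headers are available on the include path):

    #include <iostream>

    #include "absl/container/inlined_vector.h"

    int main() {
      // Up to 4 elements live in the object itself; growing past that moves
      // the storage to the heap, just like the old gtl implementation.
      absl::InlinedVector<int, 4> v = {1, 2, 3};
      std::cout << "size=" << v.size() << " capacity=" << v.capacity() << "\n";

      v.push_back(4);
      v.push_back(5);  // Exceeds the inline capacity; triggers a heap allocation.
      std::cout << "size=" << v.size() << " capacity=" << v.capacity() << "\n";

      for (int x : v) std::cout << x << " ";
      std::cout << "\n";
      return 0;
    }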
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
deleted file mode 100644
index 2721885c4a..0000000000
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ /dev/null
@@ -1,898 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-
-#include <list>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
-
-// A type that counts number of live occurrences of the type
-static int64 instances = 0;
-class Instance {
- public:
- int value_;
- explicit Instance(int x) : value_(x) { instances++; }
- Instance(const Instance& x) : value_(x.value_) { instances++; }
- ~Instance() { instances--; }
-
- friend inline void swap(Instance& a, Instance& b) {
- using std::swap;
- swap(a.value_, b.value_);
- }
-
- friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
- return o << "[value:" << v.value_ << "]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
-
-// A simple reference counted class to make sure that the proper elements are
-// destroyed in the erase(begin, end) test.
-class RefCounted {
- public:
- RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
-
- RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
- VLOG(5) << "[RefCounted: copy"
- << " from count @" << v.count_ << "]";
- Ref();
- }
-
- ~RefCounted() {
- Unref();
- count_ = nullptr;
- }
-
- friend void swap(RefCounted& a, RefCounted& b) {
- using std::swap;
- swap(a.value_, b.value_);
- swap(a.count_, b.count_);
- }
-
- RefCounted& operator=(RefCounted v) {
- using std::swap;
- swap(*this, v);
- return *this;
- }
-
- void Ref() const {
- CHECK(count_ != nullptr);
- ++(*count_);
- VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- void Unref() const {
- --(*count_);
- CHECK_GE(*count_, 0);
- VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- int count() const { return *count_; }
-
- friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
- return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
- }
-
- int value_;
- int* count_;
-};
-
-typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
-
-// A class with a vtable pointer
-class Dynamic {
- public:
- virtual ~Dynamic() {}
-
- friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
- return o << "[Dynamic]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
-
-// Append 0..len-1 to *v
-static void Fill(IntVec* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static IntVec Fill(int len, int offset = 0) {
- IntVec v;
- Fill(&v, len, offset);
- return v;
-}
-
-TEST(IntVec, SimpleOps) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- const IntVec& cv = v; // const alias
-
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- EXPECT_EQ(v.begin(), v.data());
- EXPECT_EQ(cv.begin(), cv.data());
-
- int counter = 0;
- for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, Erase) {
- for (int len = 1; len < 20; len++) {
- for (int i = 0; i < len; ++i) {
- IntVec v;
- Fill(&v, len);
- v.erase(v.begin() + i);
- EXPECT_EQ(len - 1, v.size());
- for (int j = 0; j < i; ++j) {
- EXPECT_EQ(j, v[j]);
- }
- for (int j = i; j < len - 1; ++j) {
- EXPECT_EQ(j + 1, v[j]);
- }
- }
- }
-}
-
-// At the end of this test loop, the elements between [erase_begin, erase_end)
-// should have reference counts == 0, and all others elements should have
-// reference counts == 1.
-TEST(RefCountedVec, EraseBeginEnd) {
- for (int len = 1; len < 20; ++len) {
- for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
- for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
- std::vector<int> counts(len, 0);
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- int erase_len = erase_end - erase_begin;
-
- v.erase(v.begin() + erase_begin, v.begin() + erase_end);
-
- EXPECT_EQ(len - erase_len, v.size());
-
- // Check the elements before the first element erased.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(i, v[i].value_);
- }
-
- // Check the elements after the first element erased.
- for (size_t i = erase_begin; i < v.size(); ++i) {
- EXPECT_EQ(i + erase_len, v[i].value_);
- }
-
- // Check that the elements at the beginning are preserved.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
-
- // Check that the erased elements are destroyed
- for (int i = erase_begin; i < erase_end; ++i) {
- EXPECT_EQ(0, counts[i]);
- }
-
- // Check that the elements at the end are preserved.
- for (int i = erase_end; i < len; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
- }
- }
- }
-}
-
-struct NoDefaultCtor {
- explicit NoDefaultCtor(int) {}
-};
-struct NoCopy {
- NoCopy() {}
- NoCopy(const NoCopy&) = delete;
-};
-struct NoAssign {
- NoAssign() {}
- NoAssign& operator=(const NoAssign&) = delete;
-};
-struct MoveOnly {
- MoveOnly() {}
- MoveOnly(MoveOnly&&) = default;
- MoveOnly& operator=(MoveOnly&&) = default;
-};
-TEST(InlinedVectorTest, NoDefaultCtor) {
- tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
- (void)v;
-}
-TEST(InlinedVectorTest, NoCopy) {
- tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, NoAssign) {
- tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, MoveOnly) {
- gtl::InlinedVector<MoveOnly, 2> v;
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
-}
-
-TEST(IntVec, Insert) {
- for (int len = 0; len < 20; len++) {
- for (int pos = 0; pos <= len; pos++) {
- IntVec v;
- Fill(&v, len);
- v.insert(v.begin() + pos, 9999);
- EXPECT_EQ(v.size(), len + 1);
- for (int i = 0; i < pos; i++) {
- EXPECT_EQ(v[i], i);
- }
- EXPECT_EQ(v[pos], 9999);
- for (size_t i = pos + 1; i < v.size(); i++) {
- EXPECT_EQ(v[i], i - 1);
- }
- }
- }
-}
-
-TEST(RefCountedVec, InsertConstructorDestructor) {
- // Make sure the proper construction/destruction happen during insert
- // operations.
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- for (int pos = 0; pos <= len; pos++) {
- SCOPED_TRACE(pos);
- std::vector<int> counts(len, 0);
- int inserted_count = 0;
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- SCOPED_TRACE(i);
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
-
- RefCounted insert_element(9999, &inserted_count);
- EXPECT_EQ(1, inserted_count);
- v.insert(v.begin() + pos, insert_element);
- EXPECT_EQ(2, inserted_count);
- // Check that the elements at the end are preserved.
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
- EXPECT_EQ(2, inserted_count);
- }
- }
-}
-
-TEST(IntVec, Resize) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- // Try resizing up and down by k elements
- static const int kResizeElem = 1000000;
- for (int k = 0; k < 10; k++) {
- // Enlarging resize
- v.resize(len + k, kResizeElem);
- EXPECT_EQ(len + k, v.size());
- EXPECT_LE(len + k, v.capacity());
- for (int i = 0; i < len + k; i++) {
- if (i < len) {
- EXPECT_EQ(i, v[i]);
- } else {
- EXPECT_EQ(kResizeElem, v[i]);
- }
- }
-
- // Shrinking resize
- v.resize(len, kResizeElem);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, InitWithLength) {
- for (int len = 0; len < 20; len++) {
- IntVec v(len, 7);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(7, v[i]);
- }
- }
-}
-
-TEST(IntVec, CopyConstructorAndAssignment) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- IntVec v2(v);
- EXPECT_EQ(v, v2);
-
- for (int start_len = 0; start_len < 20; start_len++) {
- IntVec v3;
- Fill(&v3, start_len, 99); // Add dummy elements that should go away
- v3 = v;
- EXPECT_EQ(v, v3);
- }
- }
-}
-
-TEST(OverheadTest, Storage) {
- // Check for size overhead.
- using tensorflow::gtl::InlinedVector;
- EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>));
- EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>));
-
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>));
- EXPECT_EQ(2 * sizeof(char*),
- sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>));
- EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>));
-}
-
-TEST(IntVec, Clear) {
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- IntVec v;
- Fill(&v, len);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntVec, Reserve) {
- for (size_t len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- for (size_t newlen = 0; newlen < 100; newlen++) {
- const int* start_rep = v.data();
- v.reserve(newlen);
- const int* final_rep = v.data();
- if (newlen <= len) {
- EXPECT_EQ(start_rep, final_rep);
- }
- EXPECT_LE(newlen, v.capacity());
-
- // Filling up to newlen should not change rep
- while (v.size() < newlen) {
- v.push_back(0);
- }
- EXPECT_EQ(final_rep, v.data());
- }
- }
-}
-
-template <typename T>
-static std::vector<typename T::value_type> Vec(const T& src) {
- std::vector<typename T::value_type> result;
- for (const auto& elem : src) {
- result.push_back(elem);
- }
- return result;
-}
-
-TEST(IntVec, SelfRefPushBack) {
- std::vector<string> std_v;
- tensorflow::gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(std_v, Vec(v));
-
- v.push_back(v.back());
- std_v.push_back(std_v.back());
- }
- EXPECT_EQ(std_v, Vec(v));
-}
-
-TEST(IntVec, SelfRefPushBackWithMove) {
- std::vector<string> std_v;
- gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(v.back(), std_v.back());
-
- v.push_back(std::move(v.back()));
- std_v.push_back(std::move(std_v.back()));
- }
- EXPECT_EQ(v.back(), std_v.back());
-}
-
-TEST(IntVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- SCOPED_TRACE(l1);
- for (int l2 = 0; l2 < 20; l2++) {
- SCOPED_TRACE(l2);
- IntVec a = Fill(l1, 0);
- IntVec b = Fill(l2, 100);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(InstanceVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- InstanceVec a, b;
- for (int i = 0; i < l1; i++) a.push_back(Instance(i));
- for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
- EXPECT_EQ(l1 + l2, instances);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1 + l2, instances);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i].value_);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i].value_);
- }
- }
- }
-}
-
-TEST(IntVec, EqualAndNotEqual) {
- IntVec a, b;
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- a.push_back(3);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b.push_back(3);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b.push_back(7);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.push_back(6);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.clear();
- b.clear();
- for (int i = 0; i < 100; i++) {
- a.push_back(i);
- b.push_back(i);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b[i] = b[i] + 1;
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b[i] = b[i] - 1; // Back to before
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
- }
-}
-
-TEST(IntVec, RelationalOps) {
- IntVec a, b;
- EXPECT_FALSE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_FALSE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_TRUE(b <= a);
- EXPECT_TRUE(a >= b);
- EXPECT_TRUE(b >= a);
- b.push_back(3);
- EXPECT_TRUE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_TRUE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_FALSE(b <= a);
- EXPECT_FALSE(a >= b);
- EXPECT_TRUE(b >= a);
-}
-
-TEST(InstanceVec, CountConstructorsDestructors) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- InstanceVec v;
- for (int i = 0; i < len; i++) {
- v.push_back(Instance(i));
- }
- EXPECT_EQ(start + len, instances);
-
- { // Copy constructor should create 'len' more instances.
- InstanceVec v_copy(v);
- EXPECT_EQ(start + len + len, instances);
- }
- EXPECT_EQ(start + len, instances);
-
- // Enlarging resize() must construct some objects
- v.resize(len + 10, Instance(100));
- EXPECT_EQ(start + len + 10, instances);
-
- // Shrinking resize() must destroy some objects
- v.resize(len, Instance(100));
- EXPECT_EQ(start + len, instances);
-
- // reserve() must not increase the number of initialized objects
- v.reserve(len + 1000);
- EXPECT_EQ(start + len, instances);
-
- // pop_back() and erase() must destroy one object
- if (len > 0) {
- v.pop_back();
- EXPECT_EQ(start + len - 1, instances);
- if (!v.empty()) {
- v.erase(v.begin());
- EXPECT_EQ(start + len - 2, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- for (int longorshort = 0; longorshort <= 1; ++longorshort) {
- InstanceVec longer, shorter;
- for (int i = 0; i < len; i++) {
- longer.push_back(Instance(i));
- shorter.push_back(Instance(i));
- }
- longer.push_back(Instance(len));
- EXPECT_EQ(start + len + len + 1, instances);
-
- if (longorshort) {
- shorter = longer;
- EXPECT_EQ(start + (len + 1) + (len + 1), instances);
- } else {
- longer = shorter;
- EXPECT_EQ(start + len + len, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(RangedConstructor, SimpleType) {
- std::vector<int> source_v = {4, 5, 6, 7};
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
- tensorflow::gtl::InlinedVector<int, 4> empty4;
- EXPECT_EQ(4, v.size());
- EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(4, v[0]);
- EXPECT_EQ(5, v[1]);
- EXPECT_EQ(6, v[2]);
- EXPECT_EQ(7, v[3]);
-
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<int, 2> empty2;
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty2.capacity(), realloc_v.capacity());
- EXPECT_EQ(4, realloc_v[0]);
- EXPECT_EQ(5, realloc_v[1]);
- EXPECT_EQ(6, realloc_v[2]);
- EXPECT_EQ(7, realloc_v[3]);
-}
-
-TEST(RangedConstructor, ComplexType) {
- // We also use a list here to pass a different flavor of iterator (e.g. not
- // random-access).
- std::list<Instance> source_v = {Instance(0)};
-
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<Instance, 1> empty1;
- EXPECT_EQ(1, v.size());
- EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(0, v[0].value_);
-
- std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2),
- Instance(3)};
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
- source_v2.end());
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty1.capacity(), realloc_v.capacity());
- EXPECT_EQ(0, realloc_v[0].value_);
- EXPECT_EQ(1, realloc_v[1].value_);
- EXPECT_EQ(2, realloc_v[2].value_);
- EXPECT_EQ(3, realloc_v[3].value_);
-}
-
-TEST(RangedConstructor, ElementsAreConstructed) {
- std::vector<string> source_v = {"cat", "dog"};
-
- // Force expansion and re-allocation of v. Ensures that new elements are
- // constructed when the vector is expanded.
- tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
- EXPECT_EQ("cat", v[0]);
- EXPECT_EQ("dog", v[1]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
- auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_EQ(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
- auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_LE(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, DisparateTypesInList) {
- EXPECT_EQ((std::vector<int>{-7, 8}),
- Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
-
- EXPECT_EQ(
- (std::vector<string>{"foo", "bar"}),
- Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
- tensorflow::gtl::InlinedVector<Instance, 1> empty;
- auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
- EXPECT_EQ(1, vec.size());
- EXPECT_EQ(empty.capacity(), vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
- auto vec =
- tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
- EXPECT_EQ(2, vec.size());
- EXPECT_LE(2, vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
- EXPECT_EQ(1, vec[1].value_);
-}
-
-TEST(DynamicVec, DynamicVecCompiles) {
- DynamicVec v;
- (void)v;
-}
-
-static void BM_InlinedVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- IntVec v;
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
-
-static void BM_InlinedVectorFillRange(int iters, int len) {
- std::unique_ptr<int[]> ia(new int[len]);
- for (int j = 0; j < len; j++) {
- ia[j] = j;
- }
- for (int i = 0; i < iters; i++) {
- IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
-
-static void BM_StdVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- std::vector<int> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
-
-bool StringRepresentedInline(string s) {
- const char* chars = s.data();
- string s1 = std::move(s);
- return s1.data() != chars;
-}
-
-static void BM_InlinedVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- gtl::InlinedVector<string, 8> v;
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
-}
-BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024);
-
-static void BM_StdVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- std::vector<string> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
- // The purpose of the benchmark is to verify that inlined vector is
- // efficient when moving is cheaper than copying. To do so, we use
- // strings that are too large for the small string optimization.
- CHECK(!StringRepresentedInline(strings[0]));
-}
-BENCHMARK(BM_StdVectorFillString)->Range(0, 1024);
-
-namespace {
-struct Buffer { // some arbitrary structure for benchmarking.
- char* base;
- int length;
- int capacity;
- void* user_data;
-};
-} // anonymous namespace
-
-static void BM_InlinedVectorTenAssignments(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
-
- BufferVec src;
- src.resize(len);
-
- iters *= 10;
- BufferVec dst;
- for (int i = 0; i < iters; i++) {
- dst = src;
- }
-}
-BENCHMARK(BM_InlinedVectorTenAssignments)
- ->Arg(0)
- ->Arg(1)
- ->Arg(2)
- ->Arg(3)
- ->Arg(4)
- ->Arg(20);
-
-static void BM_CreateFromInitializerList(int iters) {
- for (; iters > 0; iters--) {
- tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
- (void)x[0];
- }
-}
-BENCHMARK(BM_CreateFromInitializerList);
-
-namespace {
-
-struct LargeSwappable {
- LargeSwappable() : d_(1024, 17) {}
- ~LargeSwappable() {}
- LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
-
- friend void swap(LargeSwappable& a, LargeSwappable& b) {
- using std::swap;
- swap(a.d_, b.d_);
- }
-
- LargeSwappable& operator=(LargeSwappable o) {
- using std::swap;
- swap(*this, o);
- return *this;
- }
-
- std::vector<int> d_;
-};
-
-} // namespace
-
-static void BM_LargeSwappableElements(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
- Vec a(len);
- Vec b;
- while (--iters >= 0) {
- using std::swap;
- swap(a, b);
- }
-}
-BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
-
-} // namespace tensorflow
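The deleted benchmarks above rely on StringRepresentedInline() to confirm that the test strings are heap-backed rather than stored in the small-string buffer, since the comparison is only meaningful when moving is cheaper than copying. A minimal standalone sketch of that move-based SSO check (illustration only, not part of the diff; std::string storage is implementation-defined, so the printed results are merely typical):

    #include <iostream>
    #include <string>
    #include <utility>

    // If the string uses small-string (inline) storage, moving it copies the
    // characters into the destination's own buffer, so data() changes.
    // A heap-backed string usually hands its buffer over, so data() is unchanged.
    static bool StringRepresentedInline(std::string s) {
      const char* chars = s.data();
      std::string moved = std::move(s);
      return moved.data() != chars;
    }

    int main() {
      std::cout << StringRepresentedInline("hi") << "\n";  // typically 1 (inline)
      std::cout << StringRepresentedInline("a quite long string to force a heap allocation")
                << "\n";                                   // typically 0 (heap-backed)
    }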
diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h
index 7ad916ad3d..238aa18e1e 100644
--- a/tensorflow/core/lib/gtl/optional.h
+++ b/tensorflow/core/lib/gtl/optional.h
@@ -16,861 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
#define TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
-#include <assert.h>
-#include <functional>
-#include <initializer_list>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/core/platform/logging.h"
+#include "absl/types/optional.h"
namespace tensorflow {
namespace gtl {
-// A value of type gtl::optional<T> holds either a value of T or an
-// "empty" value. When it holds a value of T, it stores it as a direct
-// subobject, so sizeof(optional<T>) is approximately sizeof(T)+1. The interface
-// is based on the upcoming std::optional<T>, and gtl::optional<T> is
-// designed to be cheaply drop-in replaceable by std::optional<T>, once it is
-// rolled out.
-//
-// This implementation is based on the specification in the latest draft as of
-// 2017-01-05, section 20.6.
-//
-// Differences between gtl::optional<T> and std::optional<T> include:
-// - constexpr not used for nonconst member functions.
-// (dependency on some differences between C++11 and C++14.)
-// - nullopt and in_place are not constexpr. We need the inline variable
-// support in C++17 for external linkage.
-// - CHECK instead of throwing std::bad_optional_access.
-// - optional::swap() and swap() rely on std::is_(nothrow_)swappable,
-// which is introduced in C++17. So we assume is_swappable is always true
-// and is_nothrow_swappable is the same as std::is_trivial.
-// - make_optional cannot be constexpr due to absence of guaranteed copy
-// elision.
-//
-// Synopsis:
-//
-// #include "tensorflow/core/lib/gtl/optional.h"
-//
-// tensorflow::gtl::optional<string> f() {
-// string result;
-// if (...) {
-// ...
-// result = ...;
-// return result;
-// } else {
-// ...
-// return tensorflow::gtl::nullopt;
-// }
-// }
-//
-// int main() {
-// tensorflow::gtl::optional<string> optstr = f();
-// if (optstr) {
-// // non-empty
-// print(optstr.value());
-// } else {
-// // empty
-// error();
-// }
-// }
-template <typename T>
-class optional;
-
-// The tag constant `in_place` is used as the first parameter of an optional<T>
-// constructor to indicate that the remaining arguments should be forwarded
-// to the underlying T constructor.
-struct in_place_t {};
-extern const in_place_t in_place;
-
-// The tag constant `nullopt` is used to indicate an empty optional<T> in
-// certain functions, such as construction or assignment.
-struct nullopt_t {
- struct init_t {};
- static init_t init;
- // It must not be default-constructible to avoid ambiguity for opt = {}.
- // Note the non-const reference; it eliminates ambiguity for code like:
- // struct S { int value; };
- //
- // void Test() {
- // optional<S> opt;
- // opt = {{}};
- // }
- explicit constexpr nullopt_t(init_t& /*unused*/) {} // NOLINT
-};
-extern const nullopt_t nullopt;
-
-namespace internal_optional {
-
-// define forward locally because std::forward is not constexpr until C++14
-template <typename T>
-constexpr T&& forward(typename std::remove_reference<T>::type&
- t) noexcept { // NOLINT(runtime/references)
- return static_cast<T&&>(t);
-}
-
-struct empty_struct {};
-// This class stores the data in optional<T>.
-// It is specialized based on whether T is trivially destructible.
-// This is the specialization for non-trivially destructible types.
-template <typename T, bool = std::is_trivially_destructible<T>::value>
-class optional_data_dtor_base {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
-
- void destruct() noexcept {
- if (engaged_) {
- data_.~T();
- engaged_ = false;
- }
- }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() { destruct(); }
-};
-
-// Specialization for trivially destructible type.
-template <typename T>
-class optional_data_dtor_base<T, true> {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
- void destruct() noexcept { engaged_ = false; }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() = default;
-};
-
-template <typename T>
-class optional_data : public optional_data_dtor_base<T> {
- protected:
- using base = optional_data_dtor_base<T>;
- using base::base;
-
- T* pointer() { return &this->data_; }
-
- constexpr const T* pointer() const { return &this->data_; }
-
- template <typename... Args>
- void construct(Args&&... args) {
- new (pointer()) T(std::forward<Args>(args)...);
- this->engaged_ = true;
- }
-
- template <typename U>
- void assign(U&& u) {
- if (this->engaged_) {
- this->data_ = std::forward<U>(u);
- } else {
- construct(std::forward<U>(u));
- }
- }
-
- optional_data() = default;
-
- optional_data(const optional_data& rhs) {
- if (rhs.engaged_) {
- construct(rhs.data_);
- }
- }
-
- optional_data(optional_data&& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- construct(std::move(rhs.data_));
- }
- }
-
- optional_data& operator=(const optional_data& rhs) {
- if (rhs.engaged_) {
- assign(rhs.data_);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- optional_data& operator=(optional_data&& rhs) noexcept(
- std::is_nothrow_move_assignable<T>::value&&
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- assign(std::move(rhs.data_));
- } else {
- this->destruct();
- }
- return *this;
- }
-};
-
-// ordered by level of restriction, from low to high.
-// copyable implies movable.
-enum class copy_traits { copyable = 0, movable = 1, non_movable = 2 };
-
-// base class for enabling/disabling copy/move constructor.
-template <copy_traits>
-class optional_ctor_base;
-
-template <>
-class optional_ctor_base<copy_traits::copyable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = default;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::non_movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = delete;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-// base class for enabling/disabling copy/move assignment.
-template <copy_traits>
-class optional_assign_base;
-
-template <>
-class optional_assign_base<copy_traits::copyable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = default;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::non_movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = delete;
-};
-
+// Deprecated: please use absl::optional directly.
+using absl::make_optional;
+using absl::nullopt;
template <typename T>
-constexpr copy_traits get_ctor_copy_traits() {
- return std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_constructible<T>::value ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-template <typename T>
-constexpr copy_traits get_assign_copy_traits() {
- return std::is_copy_assignable<T>::value &&
- std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_assignable<T>::value &&
- std::is_move_constructible<T>::value
- ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-// Whether T is constructible or convertible from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_from_optional
- : std::integral_constant<
- bool, std::is_constructible<T, optional<U>&>::value ||
- std::is_constructible<T, optional<U>&&>::value ||
- std::is_constructible<T, const optional<U>&>::value ||
- std::is_constructible<T, const optional<U>&&>::value ||
- std::is_convertible<optional<U>&, T>::value ||
- std::is_convertible<optional<U>&&, T>::value ||
- std::is_convertible<const optional<U>&, T>::value ||
- std::is_convertible<const optional<U>&&, T>::value> {};
-
-// Whether T is constructible or convertible or assignable from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_assignable_from_optional
- : std::integral_constant<
- bool, is_constructible_convertible_from_optional<T, U>::value ||
- std::is_assignable<T&, optional<U>&>::value ||
- std::is_assignable<T&, optional<U>&&>::value ||
- std::is_assignable<T&, const optional<U>&>::value ||
- std::is_assignable<T&, const optional<U>&&>::value> {};
-
-} // namespace internal_optional
-
-template <typename T>
-class optional : private internal_optional::optional_data<T>,
- private internal_optional::optional_ctor_base<
- internal_optional::get_ctor_copy_traits<T>()>,
- private internal_optional::optional_assign_base<
- internal_optional::get_assign_copy_traits<T>()> {
- using data_base = internal_optional::optional_data<T>;
-
- public:
- typedef T value_type;
-
- // [optional.ctor], constructors
-
- // A default constructed optional holds the empty value, NOT a default
- // constructed T.
- constexpr optional() noexcept {}
-
- // An optional initialized with `nullopt` holds the empty value.
- constexpr optional(nullopt_t) noexcept {} // NOLINT(runtime/explicit)
-
- // Copy constructor, standard semantics.
- optional(const optional& src) = default;
-
- // Move constructor, standard semantics.
- optional(optional&& src) = default;
-
- // optional<T>(in_place, arg1, arg2, arg3) constructs a non-empty optional
- // with an in-place constructed value of T(arg1,arg2,arg3).
- // TODO(b/34201852): Add std::is_constructible<T, Args&&...> SFINAE.
- template <typename... Args>
- constexpr explicit optional(in_place_t, Args&&... args)
- : data_base(in_place_t(), internal_optional::forward<Args>(args)...) {}
-
- // optional<T>(in_place, {arg1, arg2, arg3}) constructs a non-empty optional
- // with an in-place list-initialized value of T({arg1, arg2, arg3}).
- template <typename U, typename... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- constexpr explicit optional(in_place_t, std::initializer_list<U> il,
- Args&&... args)
- : data_base(in_place_t(), il, internal_optional::forward<Args>(args)...) {
- }
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- constexpr optional(U&& v) // NOLINT
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit constexpr optional(U&& v)
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- // Converting copy constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<const U&, T>::value,
- bool>::type = false>
- optional(const optional<U>& rhs) { // NOLINT
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting copy constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<const U&, T>::value,
- bool>::type = false>
- explicit optional(const optional<U>& rhs) {
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting move constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- optional(optional<U>&& rhs) { // NOLINT
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // Converting move constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit optional(optional<U>&& rhs) {
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // [optional.dtor], destructor, trivial if T is trivially destructible.
- ~optional() = default;
-
- // [optional.assign], assignment
-
- // Assignment from nullopt: opt = nullopt
- optional& operator=(nullopt_t) noexcept {
- this->destruct();
- return *this;
- }
-
- // Copy assignment, standard semantics.
- optional& operator=(const optional& src) = default;
-
- // Move assignment, standard semantics.
- optional& operator=(optional&& src) = default;
-
- // Value assignment
- template <
- typename U = T,
- typename = typename std::enable_if<
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- (!std::is_scalar<T>::value ||
- !std::is_same<T, typename std::decay<U>::type>::value) &&
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value>::type>
- optional& operator=(U&& v) {
- this->assign(std::forward<U>(v));
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- std::is_assignable<T&, const U&>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(const optional<U>& rhs) {
- if (rhs) {
- this->assign(*rhs);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(optional<U>&& rhs) {
- if (rhs) {
- this->assign(std::move(*rhs));
- } else {
- this->destruct();
- }
- return *this;
- }
-
- // [optional.mod], modifiers
- // Destroys the inner T value if one is present.
- void reset() noexcept { this->destruct(); }
-
- // Emplace reconstruction. (Re)constructs the underlying T in-place with the
- // given arguments forwarded:
- //
- // optional<Foo> opt;
- // opt.emplace(arg1,arg2,arg3); (Constructs Foo(arg1,arg2,arg3))
- //
- // If the optional is non-empty, and the `args` refer to subobjects of the
- // current object, then behavior is undefined. This is because the current
- // object will be destructed before the new object is constructed with `args`.
- //
- template <typename... Args,
- typename = typename std::enable_if<
- std::is_constructible<T, Args&&...>::value>::type>
- void emplace(Args&&... args) {
- this->destruct();
- this->construct(std::forward<Args>(args)...);
- }
-
- // Emplace reconstruction with initializer-list. See immediately above.
- template <class U, class... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- void emplace(std::initializer_list<U> il, Args&&... args) {
- this->destruct();
- this->construct(il, std::forward<Args>(args)...);
- }
-
- // [optional.swap], swap
- // Swap, standard semantics.
- void swap(optional& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value&&
- std::is_trivial<T>::value) {
- if (*this) {
- if (rhs) {
- using std::swap;
- swap(**this, *rhs);
- } else {
- rhs.construct(std::move(**this));
- this->destruct();
- }
- } else {
- if (rhs) {
- this->construct(std::move(*rhs));
- rhs.destruct();
- } else {
- // no effect (swap(disengaged, disengaged))
- }
- }
- }
-
- // [optional.observe], observers
- // You may use `*opt` and `opt->m` to access the underlying T value and T's
- // member `m`, respectively. If the optional is empty, behavior is
- // undefined.
- constexpr const T* operator->() const { return this->pointer(); }
- T* operator->() {
- assert(this->engaged_);
- return this->pointer();
- }
- constexpr const T& operator*() const& { return reference(); }
- T& operator*() & {
- assert(this->engaged_);
- return reference();
- }
- constexpr const T&& operator*() const&& { return std::move(reference()); }
- T&& operator*() && {
- assert(this->engaged_);
- return std::move(reference());
- }
-
- // In a bool context an optional<T> will return false if and only if it is
- // empty.
- //
- // if (opt) {
- // // do something with opt.value();
- // } else {
- // // opt is empty
- // }
- //
- constexpr explicit operator bool() const noexcept { return this->engaged_; }
-
- // Returns false if and only if *this is empty.
- constexpr bool has_value() const noexcept { return this->engaged_; }
-
- // Use `opt.value()` to get a reference to the underlying value. The constness
- // and lvalue/rvalue-ness of `opt` are preserved in the view of the T
- // subobject.
- const T& value() const& {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T& value() & {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T&& value() && { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
- const T&& value() const&& { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
-
- // Use `opt.value_or(val)` to get either the value of T or the given default
- // `val` in the empty case.
- template <class U>
- constexpr T value_or(U&& v) const& {
- return static_cast<bool>(*this) ? **this
- : static_cast<T>(std::forward<U>(v));
- }
- template <class U>
- T value_or(U&& v) && { // NOLINT(build/c++11)
- return static_cast<bool>(*this) ? std::move(**this)
- : static_cast<T>(std::forward<U>(v));
- }
-
- private:
- // Private accessors for internal storage viewed as reference to T.
- constexpr const T& reference() const { return *this->pointer(); }
- T& reference() { return *(this->pointer()); }
-
- // T constraint checks. You can't have an optional of nullopt_t, in_place_t
- // or a reference.
- static_assert(
- !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
- "optional<nullopt_t> is not allowed.");
- static_assert(
- !std::is_same<in_place_t, typename std::remove_cv<T>::type>::value,
- "optional<in_place_t> is not allowed.");
- static_assert(!std::is_reference<T>::value,
- "optional<reference> is not allowed.");
-};
-
-// [optional.specalg]
-// Swap, standard semantics.
-// This function shall not participate in overload resolution unless
-// is_move_constructible_v<T> is true and is_swappable_v<T> is true.
-// NOTE: we assume is_swappable is always true. There will be a compile error
-// if T is actually not Swappable.
-template <typename T,
- typename std::enable_if<std::is_move_constructible<T>::value,
- bool>::type = false>
-void swap(optional<T>& a, optional<T>& b) noexcept(noexcept(a.swap(b))) {
- a.swap(b);
-}
-
-// NOTE: make_optional cannot be constexpr in C++11 because the copy/move
-// constructor is not constexpr and we don't have guaranteed copy elision
-// until C++17. But they are still declared constexpr for consistency with
-// the standard.
-
-// make_optional(v) creates a non-empty optional<T> where the type T is deduced
-// from v. Can also be explicitly instantiated as make_optional<T>(v).
-template <typename T>
-constexpr optional<typename std::decay<T>::type> make_optional(T&& v) {
- return optional<typename std::decay<T>::type>(std::forward<T>(v));
-}
-
-template <typename T, typename... Args>
-constexpr optional<T> make_optional(Args&&... args) {
- return optional<T>(in_place_t(), internal_optional::forward<Args>(args)...);
-}
-
-template <typename T, typename U, typename... Args>
-constexpr optional<T> make_optional(std::initializer_list<U> il,
- Args&&... args) {
- return optional<T>(in_place_t(), il,
- internal_optional::forward<Args>(args)...);
-}
-
-// Relational operators. Empty optionals are considered equal to each
-// other and less than non-empty optionals. Supports relations between
-// optional<T> and optional<T>, between optional<T> and T, and between
-// optional<T> and nullopt.
-// Note: We're careful to support T having non-bool relationals.
-
-// Relational operators [optional.relops]
-// The C++17 (N4606) "Returns:" statements are translated into code
-// in an obvious way here, and the original text retained as function docs.
-// Returns: If bool(x) != bool(y), false; otherwise if bool(x) == false, true;
-// otherwise *x == *y.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? false
- : static_cast<bool>(x) == false ? true : *x == *y;
-}
-// Returns: If bool(x) != bool(y), true; otherwise, if bool(x) == false, false;
-// otherwise *x != *y.
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? true
- : static_cast<bool>(x) == false ? false : *x != *y;
-}
-// Returns: If !y, false; otherwise, if !x, true; otherwise *x < *y.
-template <class T>
-constexpr bool operator<(const optional<T>& x, const optional<T>& y) {
- return !y ? false : !x ? true : *x < *y;
-}
-// Returns: If !x, false; otherwise, if !y, true; otherwise *x > *y.
-template <class T>
-constexpr bool operator>(const optional<T>& x, const optional<T>& y) {
- return !x ? false : !y ? true : *x > *y;
-}
-// Returns: If !x, true; otherwise, if !y, false; otherwise *x <= *y.
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const optional<T>& y) {
- return !x ? true : !y ? false : *x <= *y;
-}
-// Returns: If !y, true; otherwise, if !x, false; otherwise *x >= *y.
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const optional<T>& y) {
- return !y ? true : !x ? false : *x >= *y;
-}
-
-// Comparison with nullopt [optional.nullops]
-// The C++17 (N4606) "Returns:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator==(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator!=(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, nullopt_t) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator<(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator<=(nullopt_t, const optional<T>& x) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator>(nullopt_t, const optional<T>& x) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, nullopt_t) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>=(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-
-// Comparison with T [optional.comp_with_t]
-// The C++17 (N4606) "Equivalent to:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x == v : false;
-}
-template <class T>
-constexpr bool operator==(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v == *x : false;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x != v : true;
-}
-template <class T>
-constexpr bool operator!=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v != *x : true;
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x < v : true;
-}
-template <class T>
-constexpr bool operator<(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v < *x : false;
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x <= v : true;
-}
-template <class T>
-constexpr bool operator<=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v <= *x : false;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x > v : false;
-}
-template <class T>
-constexpr bool operator>(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v > *x : true;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x >= v : false;
-}
-template <class T>
-constexpr bool operator>=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v >= *x : true;
-}
+using optional = absl::optional<T>;
} // namespace gtl
} // namespace tensorflow
-namespace std {
-
-// Normally std::hash specializations are not recommended in tensorflow code,
-// but we allow this one because it follows a standard library component.
-template <class T>
-struct hash<::tensorflow::gtl::optional<T>> {
- size_t operator()(const ::tensorflow::gtl::optional<T>& opt) const {
- if (opt) {
- return hash<T>()(*opt);
- } else {
- return static_cast<size_t>(0x297814aaad196e6dULL);
- }
- }
-};
-
-} // namespace std
-
#endif // TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
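With the hand-rolled implementation removed, optional.h now simply re-exports the Abseil types, so existing gtl::optional call sites keep compiling while new code is expected to spell absl::optional directly. A minimal usage sketch of the aliases left in the header (assumes Abseil is on the include path; MaybeParse and Demo are illustrative names, not part of the codebase):

    #include "absl/types/optional.h"
    #include "tensorflow/core/lib/gtl/optional.h"

    // Hypothetical helper: yields a value only when parsing succeeds.
    absl::optional<int> MaybeParse(bool ok) {
      if (!ok) return absl::nullopt;  // also reachable as tensorflow::gtl::nullopt
      return 42;
    }

    void Demo() {
      // Legacy spelling still compiles: gtl::optional is now an alias template
      // for absl::optional.
      tensorflow::gtl::optional<int> legacy = MaybeParse(true);
      if (legacy.has_value()) {
        int doubled = *legacy * 2;
        (void)doubled;
      }
      // Preferred spelling going forward; gtl::make_optional forwards to
      // absl::make_optional.
      absl::optional<int> preferred = tensorflow::gtl::make_optional(7);
      (void)preferred;
    }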
diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc
deleted file mode 100644
index 12b5bbc60b..0000000000
--- a/tensorflow/core/lib/gtl/optional_test.cc
+++ /dev/null
@@ -1,1098 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/optional.h"
-
-#include <string>
-#include <utility>
-
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace {
-
-using tensorflow::gtl::in_place;
-using tensorflow::gtl::in_place_t;
-using tensorflow::gtl::make_optional;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::nullopt_t;
-using tensorflow::gtl::optional;
-
-template <typename T>
-string TypeQuals(T&) {
- return "&";
-}
-template <typename T>
-string TypeQuals(T&&) {
- return "&&";
-}
-template <typename T>
-string TypeQuals(const T&) {
- return "c&";
-}
-template <typename T>
-string TypeQuals(const T&&) {
- return "c&&";
-}
-
-struct StructorListener {
- int construct0 = 0;
- int construct1 = 0;
- int construct2 = 0;
- int listinit = 0;
- int copy = 0;
- int move = 0;
- int copy_assign = 0;
- int move_assign = 0;
- int destruct = 0;
-};
-
-struct Listenable {
- static StructorListener* listener;
-
- Listenable() { ++listener->construct0; }
- Listenable(int /*unused*/) { ++listener->construct1; } // NOLINT
- Listenable(int /*unused*/, int /*unused*/) { ++listener->construct2; }
- Listenable(std::initializer_list<int> /*unused*/) { ++listener->listinit; }
- Listenable(const Listenable& /*unused*/) { ++listener->copy; }
- Listenable(Listenable&& /*unused*/) { ++listener->move; } // NOLINT
- Listenable& operator=(const Listenable& /*unused*/) {
- ++listener->copy_assign;
- return *this;
- }
- Listenable& operator=(Listenable&& /*unused*/) { // NOLINT
- ++listener->move_assign;
- return *this;
- }
- ~Listenable() { ++listener->destruct; }
-};
-
-StructorListener* Listenable::listener = nullptr;
-
-// clang on macos -- even the latest major version at time of writing (8.x) --
-// does not like much of our constexpr business. clang < 3.0 also has trouble.
-#if defined(__clang__) && defined(__APPLE__)
-#define SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-#endif
-
-struct ConstexprType {
- constexpr ConstexprType() : x(0) {}
- constexpr explicit ConstexprType(int i) : x(i) {}
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr ConstexprType(std::initializer_list<int> il) : x(il.size()) {}
-#endif
- constexpr ConstexprType(const char* s) : x(-1) {} // NOLINT
- int x;
-};
-
-struct Copyable {
- Copyable() {}
- Copyable(const Copyable&) {}
- Copyable& operator=(const Copyable&) { return *this; }
-};
-
-struct MoveableThrow {
- MoveableThrow() {}
- MoveableThrow(MoveableThrow&&) {}
- MoveableThrow& operator=(MoveableThrow&&) { return *this; }
-};
-
-struct MoveableNoThrow {
- MoveableNoThrow() {}
- MoveableNoThrow(MoveableNoThrow&&) noexcept {}
- MoveableNoThrow& operator=(MoveableNoThrow&&) noexcept { return *this; }
-};
-
-struct NonMovable {
- NonMovable() {}
- NonMovable(const NonMovable&) = delete;
- NonMovable& operator=(const NonMovable&) = delete;
- NonMovable(NonMovable&&) = delete;
- NonMovable& operator=(NonMovable&&) = delete;
-};
-
-TEST(optionalTest, DefaultConstructor) {
- optional<int> empty;
- EXPECT_FALSE(!!empty);
- constexpr optional<int> cempty;
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE(std::is_nothrow_default_constructible<optional<int>>::value);
-}
-
-TEST(optionalTest, NullOptConstructor) {
- optional<int> empty(nullopt);
- EXPECT_FALSE(!!empty);
- // Creating a temporary nullopt_t object instead of using nullopt because
- // nullopt cannot be constexpr and have external linkage at the same time.
- constexpr optional<int> cempty{nullopt_t(nullopt_t::init)};
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE((std::is_nothrow_constructible<optional<int>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_copy(empty);
- EXPECT_FALSE(!!empty_copy);
- optional<int> opt42_copy(opt42);
- EXPECT_TRUE(!!opt42_copy);
- EXPECT_EQ(42, opt42_copy);
- // test copyability
- EXPECT_TRUE(std::is_copy_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_copy_constructible<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_move(std::move(empty));
- EXPECT_FALSE(!!empty_move);
- optional<int> opt42_move(std::move(opt42));
- EXPECT_TRUE(!!opt42_move);
- EXPECT_EQ(42, opt42_move);
- // test movability
- EXPECT_TRUE(std::is_move_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_constructible<optional<NonMovable>>::value);
- // test noexcept
- EXPECT_TRUE(std::is_nothrow_move_constructible<optional<int>>::value);
- EXPECT_FALSE(
- std::is_nothrow_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_constructible<optional<MoveableNoThrow>>::value);
-}
-
-TEST(optionalTest, Destructor) {
- struct Trivial {};
-
- struct NonTrivial {
- ~NonTrivial() {}
- };
-
- EXPECT_TRUE(std::is_trivially_destructible<optional<int>>::value);
- EXPECT_TRUE(std::is_trivially_destructible<optional<Trivial>>::value);
- EXPECT_FALSE(std::is_trivially_destructible<optional<NonTrivial>>::value);
-}
-
-TEST(optionalTest, InPlaceConstructor) {
- constexpr optional<ConstexprType> opt0{in_place_t()};
- static_assert(opt0, "");
- static_assert(opt0->x == 0, "");
- constexpr optional<ConstexprType> opt1{in_place_t(), 1};
- static_assert(opt1, "");
- static_assert(opt1->x == 1, "");
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<ConstexprType> opt2{in_place_t(), {1, 2}};
- static_assert(opt2, "");
- static_assert(opt2->x == 2, "");
-#endif
-
- // TODO(b/34201852): uncomment these when std::is_constructible<T, Args&&...>
- // SFINAE is added to optional::optional(in_place_t, Args&&...).
- // struct I {
- // I(in_place_t);
- // };
-
- // EXPECT_FALSE((std::is_constructible<optional<I>, in_place_t>::value));
- // EXPECT_FALSE((std::is_constructible<optional<I>, const
- // in_place_t&>::value));
-}
-
-// template<U=T> optional(U&&);
-TEST(optionalTest, ValueConstructor) {
- constexpr optional<int> opt0(0);
- static_assert(opt0, "");
- static_assert(*opt0 == 0, "");
- EXPECT_TRUE((std::is_convertible<int, optional<int>>::value));
- // Copy initialization ( = "abc") won't work because optional(optional&&)
- // is not constexpr. Use list initialization instead. This invokes
- // optional<ConstexprType>::optional<U>(U&&), with U = const char (&) [4],
- // which direct-initializes the ConstexprType value held by the optional
- // via ConstexprType::ConstexprType(const char*).
- constexpr optional<ConstexprType> opt1 = {"abc"};
- static_assert(opt1, "");
- static_assert(-1 == opt1->x, "");
- EXPECT_TRUE(
- (std::is_convertible<const char*, optional<ConstexprType>>::value));
- // direct initialization
- constexpr optional<ConstexprType> opt2{2};
- static_assert(opt2, "");
- static_assert(2 == opt2->x, "");
- EXPECT_FALSE((std::is_convertible<int, optional<ConstexprType>>::value));
-
- // this invokes optional<int>::optional(int&&)
- // NOTE: this has different behavior than assignment, e.g.
- // "opt3 = {};" clears the optional rather than setting the value to 0
- constexpr optional<int> opt3({});
- static_assert(opt3, "");
- static_assert(*opt3 == 0, "");
-
- // this invokes the move constructor with a default constructed optional
- // because the non-template function is a better match than the template one.
- optional<ConstexprType> opt4({});
- EXPECT_FALSE(!!opt4);
-}
-
-struct Implicit {};
-
-struct Explicit {};
-
-struct Convert {
- Convert(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false) {}
- Convert(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true) {}
- explicit Convert(const Explicit&) : implicit(false), move(false) {}
- explicit Convert(Explicit&&) : implicit(false), move(true) {}
-
- bool implicit;
- bool move;
-};
-
-struct ConvertFromOptional {
- ConvertFromOptional(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(false) {}
- ConvertFromOptional(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(false) {}
- ConvertFromOptional(const optional<Implicit>&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(true) {}
- ConvertFromOptional(optional<Implicit>&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(true) {}
- explicit ConvertFromOptional(const Explicit&)
- : implicit(false), move(false), from_optional(false) {}
- explicit ConvertFromOptional(Explicit&&)
- : implicit(false), move(true), from_optional(false) {}
- explicit ConvertFromOptional(const optional<Explicit>&)
- : implicit(false), move(false), from_optional(true) {}
- explicit ConvertFromOptional(optional<Explicit>&&)
- : implicit(false), move(true), from_optional(true) {}
-
- bool implicit;
- bool move;
- bool from_optional;
-};
-
-TEST(optionalTest, ConvertingConstructor) {
- optional<Implicit> i_empty;
- optional<Implicit> i(in_place);
- optional<Explicit> e_empty;
- optional<Explicit> e(in_place);
- {
- // implicitly constructing optional<Convert> from optional<Implicit>
- optional<Convert> empty = i_empty;
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy = i;
- EXPECT_TRUE(!!opt_copy);
- EXPECT_TRUE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- optional<Convert> opt_move = optional<Implicit>(in_place);
- EXPECT_TRUE(!!opt_move);
- EXPECT_TRUE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- }
- {
- // explicitly constructing optional<Convert> from optional<Explicit>
- optional<Convert> empty(e_empty);
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy(e);
- EXPECT_TRUE(!!opt_copy);
- EXPECT_FALSE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<Convert>>::value));
- optional<Convert> opt_move{optional<Explicit>(in_place)};
- EXPECT_TRUE(!!opt_move);
- EXPECT_FALSE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- EXPECT_FALSE(
- (std::is_convertible<optional<Explicit>&&, optional<Convert>>::value));
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Implicit> via ConvertFromOptional(optional<Implicit>&&)
- // check that ConvertFromOptional(Implicit&&) is NOT called
- static_assert(
- gtl::internal_optional::is_constructible_convertible_from_optional<
- ConvertFromOptional, Implicit>::value,
- "");
- optional<ConvertFromOptional> opt0 = i_empty;
- EXPECT_TRUE(!!opt0);
- EXPECT_TRUE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- optional<ConvertFromOptional> opt1 = optional<Implicit>();
- EXPECT_TRUE(!!opt1);
- EXPECT_TRUE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Explicit> via ConvertFromOptional(optional<Explicit>&&)
- // check that ConvertFromOptional(Explicit&&) is NOT called
- optional<ConvertFromOptional> opt0(e_empty);
- EXPECT_TRUE(!!opt0);
- EXPECT_FALSE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<ConvertFromOptional>>::value));
- optional<ConvertFromOptional> opt1{optional<Explicit>()};
- EXPECT_TRUE(!!opt1);
- EXPECT_FALSE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- EXPECT_FALSE((std::is_convertible<optional<Explicit>&&,
- optional<ConvertFromOptional>>::value));
- }
-}
-
-TEST(optionalTest, StructorBasic) {
- StructorListener listener;
- Listenable::listener = &listener;
- {
- optional<Listenable> empty;
- EXPECT_FALSE(!!empty);
- optional<Listenable> opt0(in_place);
- EXPECT_TRUE(!!opt0);
- optional<Listenable> opt1(in_place, 1);
- EXPECT_TRUE(!!opt1);
- optional<Listenable> opt2(in_place, 1, 2);
- EXPECT_TRUE(!!opt2);
- }
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(3, listener.destruct);
-}
-
-TEST(optionalTest, CopyMoveStructor) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> original(in_place);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> copy(original);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> move(std::move(original));
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(1, listener.move);
-}
-
-TEST(optionalTest, ListInit) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> listinit1(in_place, {1});
- optional<Listenable> listinit2(in_place, {1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, AssignFromNullopt) {
- optional<int> opt(1);
- opt = nullopt;
- EXPECT_FALSE(!!opt);
-
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt1(in_place);
- opt1 = nullopt;
- EXPECT_FALSE(opt1);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.destruct);
-
- EXPECT_TRUE((std::is_nothrow_assignable<optional<int>, nullopt_t>::value));
- EXPECT_TRUE(
- (std::is_nothrow_assignable<optional<Listenable>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyAssignment) {
- const optional<int> empty, opt1 = 1, opt2 = 2;
- optional<int> empty_to_opt1, opt1_to_opt2, opt2_to_empty;
-
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = empty;
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = opt1;
- EXPECT_TRUE(!!empty_to_opt1);
- EXPECT_EQ(1, empty_to_opt1.value());
-
- EXPECT_FALSE(!!opt1_to_opt2);
- opt1_to_opt2 = opt1;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(1, opt1_to_opt2.value());
- opt1_to_opt2 = opt2;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(2, opt1_to_opt2.value());
-
- EXPECT_FALSE(!!opt2_to_empty);
- opt2_to_empty = opt2;
- EXPECT_TRUE(!!opt2_to_empty);
- EXPECT_EQ(2, opt2_to_empty.value());
- opt2_to_empty = empty;
- EXPECT_FALSE(!!opt2_to_empty);
-
- EXPECT_TRUE(std::is_copy_assignable<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveAssignment) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> empty1, empty2, set1(in_place), set2(in_place);
- EXPECT_EQ(2, listener.construct0);
- optional<Listenable> empty_to_empty, empty_to_set, set_to_empty(in_place),
- set_to_set(in_place);
- EXPECT_EQ(4, listener.construct0);
- empty_to_empty = std::move(empty1);
- empty_to_set = std::move(set1);
- set_to_empty = std::move(empty2);
- set_to_set = std::move(set2);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(1, listener.destruct);
- EXPECT_EQ(1, listener.move_assign);
-
- EXPECT_TRUE(std::is_move_assignable<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_assignable<optional<NonMovable>>::value);
-
- EXPECT_FALSE(std::is_nothrow_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_assignable<optional<MoveableNoThrow>>::value);
-}
-
-struct NoConvertToOptional {
- // disable implicit conversion from const NoConvertToOptional&
- // to optional<NoConvertToOptional>.
- NoConvertToOptional(const NoConvertToOptional&) = delete;
-};
-
-struct CopyConvert {
- CopyConvert(const NoConvertToOptional&);
- CopyConvert& operator=(const CopyConvert&) = delete;
- CopyConvert& operator=(const NoConvertToOptional&);
-};
-
-struct CopyConvertFromOptional {
- CopyConvertFromOptional(const NoConvertToOptional&);
- CopyConvertFromOptional(const optional<NoConvertToOptional>&);
- CopyConvertFromOptional& operator=(const CopyConvertFromOptional&) = delete;
- CopyConvertFromOptional& operator=(const NoConvertToOptional&);
- CopyConvertFromOptional& operator=(const optional<NoConvertToOptional>&);
-};
-
-struct MoveConvert {
- MoveConvert(NoConvertToOptional&&);
- MoveConvert& operator=(const MoveConvert&) = delete;
- MoveConvert& operator=(NoConvertToOptional&&);
-};
-
-struct MoveConvertFromOptional {
- MoveConvertFromOptional(NoConvertToOptional&&);
- MoveConvertFromOptional(optional<NoConvertToOptional>&&);
- MoveConvertFromOptional& operator=(const MoveConvertFromOptional&) = delete;
- MoveConvertFromOptional& operator=(NoConvertToOptional&&);
- MoveConvertFromOptional& operator=(optional<NoConvertToOptional>&&);
-};
-
-// template <class U = T> optional<T>& operator=(U&& v);
-TEST(optionalTest, ValueAssignment) {
- optional<int> opt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = nullopt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = 43;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(43, opt.value());
- opt = {}; // this should clear optional
- EXPECT_FALSE(!!opt);
-
- opt = {44};
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(44, opt.value());
-
- // U = const NoConvertToOptional&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvert>&,
- const NoConvertToOptional&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvert>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- NoConvertToOptional&&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvertFromOptional, const
- // NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- NoConvertToOptional&&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
- // U = optional<NoConvertToOptional>
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- optional<NoConvertToOptional>&&>::value));
-}
-
-// template <class U> optional<T>& operator=(const optional<U>& rhs);
-// template <class U> optional<T>& operator=(optional<U>&& rhs);
-TEST(optionalTest, ConvertingAssignment) {
- optional<int> opt_i;
- optional<char> opt_c('c');
- opt_i = opt_c;
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ(*opt_c, *opt_i);
- opt_i = optional<char>();
- EXPECT_FALSE(!!opt_i);
- opt_i = optional<char>('d');
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ('d', *opt_i);
-
- optional<string> opt_str;
- optional<const char*> opt_cstr("abc");
- opt_str = opt_cstr;
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("abc"), *opt_str);
- opt_str = optional<const char*>();
- EXPECT_FALSE(!!opt_str);
- opt_str = optional<const char*>("def");
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("def"), *opt_str);
-
- // operator=(const optional<U>&) with U = NoConvertToOptional
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvert>,
- const optional<NoConvertToOptional>&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional
- // triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvert>&,
- const optional<NoConvertToOptional>&>::value));
- // operator=(optional<U>&&) with U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- optional<NoConvertToOptional>&&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional triggers SFINAE
- // because std::is_constructible_v<MoveConvertFromOptional,
- // const NoConvertToOptional&> is false.
- // operator=(U&&) with U = const optional<NoConvertToOptional>& triggers SFINAE
- // because std::is_constructible<MoveConvertFromOptional,
- // optional<NoConvertToOptional>&&> is true.
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
-}
-
-TEST(optionalTest, ResetAndHasValue) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- opt.emplace();
- EXPECT_TRUE(!!opt);
- EXPECT_TRUE(opt.has_value());
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- EXPECT_EQ(1, listener.destruct);
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
-
- constexpr optional<int> empty;
- static_assert(!empty.has_value(), "");
- constexpr optional<int> nonempty(1);
- static_assert(nonempty.has_value(), "");
-}
-
-TEST(optionalTest, Emplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace(1);
- EXPECT_TRUE(!!opt);
- opt.emplace(1, 2);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, ListEmplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace({1});
- EXPECT_TRUE(!!opt);
- opt.emplace({1, 2});
- EXPECT_EQ(2, listener.listinit);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, Swap) {
- optional<int> opt_empty, opt1 = 1, opt2 = 2;
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt1);
- EXPECT_TRUE(!!opt_empty);
- EXPECT_EQ(1, opt_empty.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt1, opt2);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(2, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(1, opt2.value());
-
- EXPECT_TRUE(noexcept(opt1.swap(opt2)));
- EXPECT_TRUE(noexcept(swap(opt1, opt2)));
-}
-
-TEST(optionalTest, PointerStuff) {
- optional<string> opt(in_place, "foo");
- EXPECT_EQ("foo", *opt);
- const auto& opt_const = opt;
- EXPECT_EQ("foo", *opt_const);
- EXPECT_EQ(opt->size(), 3);
- EXPECT_EQ(opt_const->size(), 3);
-
- constexpr optional<ConstexprType> opt1(1);
- static_assert(opt1->x == 1, "");
-}
-
-// gcc has a bug pre 4.9 where it doesn't do correct overload resolution
-// between rvalue reference qualified member methods. Skip that test to make
-// the build green again when using the old compiler.
-#if defined(__GNUC__) && !defined(__clang__)
-#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 9)
-#define SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
-#endif
-#endif
-
-TEST(optionalTest, Value) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", lvalue.value());
- EXPECT_EQ("clvalue", clvalue.value());
- EXPECT_EQ("xvalue", O(in_place, "xvalue").value());
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", CO(in_place, "cxvalue").value());
- EXPECT_EQ("&", TypeQuals(lvalue.value()));
- EXPECT_EQ("c&", TypeQuals(clvalue.value()));
- EXPECT_EQ("&&", TypeQuals(O(in_place, "xvalue").value()));
- EXPECT_EQ("c&&", TypeQuals(CO(in_place, "cxvalue").value()));
-#endif
-}
-
-TEST(optionalTest, DerefOperator) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", *lvalue);
- EXPECT_EQ("clvalue", *clvalue);
- EXPECT_EQ("xvalue", *O(in_place, "xvalue"));
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", *CO(in_place, "cxvalue"));
- EXPECT_EQ("&", TypeQuals(*lvalue));
- EXPECT_EQ("c&", TypeQuals(*clvalue));
- EXPECT_EQ("&&", TypeQuals(*O(in_place, "xvalue")));
- EXPECT_EQ("c&&", TypeQuals(*CO(in_place, "cxvalue")));
-#endif
-
- constexpr optional<int> opt1(1);
- static_assert(*opt1 == 1, "");
-
-#if !defined(SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG) && \
- !defined(SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG)
- using COI = const optional<int>;
- static_assert(*COI(2) == 2, "");
-#endif
-}
-
-TEST(optionalTest, ValueOr) {
- optional<double> opt_empty, opt_set = 1.2;
- EXPECT_EQ(42.0, opt_empty.value_or(42));
- EXPECT_EQ(1.2, opt_set.value_or(42));
- EXPECT_EQ(42.0, optional<double>().value_or(42));
- EXPECT_EQ(1.2, optional<double>(1.2).value_or(42));
-
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<double> copt_empty;
- static_assert(42.0 == copt_empty.value_or(42), "");
-
- constexpr optional<double> copt_set = {1.2};
- static_assert(1.2 == copt_set.value_or(42), "");
-
- using COD = const optional<double>;
- static_assert(42.0 == COD().value_or(42), "");
- static_assert(1.2 == COD(1.2).value_or(42), "");
-#endif
-}
-
-// make_optional cannot be constexpr until C++17
-TEST(optionalTest, make_optional) {
- auto opt_int = make_optional(42);
- EXPECT_TRUE((std::is_same<decltype(opt_int), optional<int>>::value));
- EXPECT_EQ(42, opt_int);
-
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> opt0 = make_optional<Listenable>();
- EXPECT_EQ(1, listener.construct0);
- optional<Listenable> opt1 = make_optional<Listenable>(1);
- EXPECT_EQ(1, listener.construct1);
- optional<Listenable> opt2 = make_optional<Listenable>(1, 2);
- EXPECT_EQ(1, listener.construct2);
- optional<Listenable> opt3 = make_optional<Listenable>({1});
- optional<Listenable> opt4 = make_optional<Listenable>({1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, Comparisons) {
- optional<int> ae, be, a2 = 2, b2 = 2, a4 = 4, b4 = 4;
-
-#define optionalTest_Comparisons_EXPECT_LESS(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_TRUE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_FALSE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_SAME(x, y) \
- EXPECT_TRUE((x) == (y)); \
- EXPECT_FALSE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_GREATER(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_TRUE((x) > (y)); \
- EXPECT_FALSE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
- // LHS: nullopt, ae, a2, 3, a4
- // RHS: nullopt, be, b2, 3, b4
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,nullopt);
- optionalTest_Comparisons_EXPECT_SAME(nullopt, be);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b2);
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,3);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b4);
-
- optionalTest_Comparisons_EXPECT_SAME(ae, nullopt);
- optionalTest_Comparisons_EXPECT_SAME(ae, be);
- optionalTest_Comparisons_EXPECT_LESS(ae, b2);
- optionalTest_Comparisons_EXPECT_LESS(ae, 3);
- optionalTest_Comparisons_EXPECT_LESS(ae, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a2, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a2, be);
- optionalTest_Comparisons_EXPECT_SAME(a2, b2);
- optionalTest_Comparisons_EXPECT_LESS(a2, 3);
- optionalTest_Comparisons_EXPECT_LESS(a2, b4);
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(3,nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(3, be);
- optionalTest_Comparisons_EXPECT_GREATER(3, b2);
- optionalTest_Comparisons_EXPECT_SAME(3, 3);
- optionalTest_Comparisons_EXPECT_LESS(3, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a4, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a4, be);
- optionalTest_Comparisons_EXPECT_GREATER(a4, b2);
- optionalTest_Comparisons_EXPECT_GREATER(a4, 3);
- optionalTest_Comparisons_EXPECT_SAME(a4, b4);
-}
-
-TEST(optionalTest, SwapRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- optional<Listenable> b(in_place);
- a.swap(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-
- {
- optional<Listenable> a(in_place);
- optional<Listenable> b;
- a.swap(b);
- }
-
- EXPECT_EQ(2, listener.construct0);
- EXPECT_EQ(2, listener.move);
- EXPECT_EQ(4, listener.destruct);
-}
-
-TEST(optionalTest, BigStringLeakCheck) {
- constexpr size_t n = 1 << 16;
-
- using OS = optional<string>;
-
- OS a;
- OS b = nullopt;
- OS c = string(n, 'c');
- string sd(n, 'd');
- OS d = sd;
- OS e(in_place, n, 'e');
- OS f;
- f.emplace(n, 'f');
-
- OS ca(a);
- OS cb(b);
- OS cc(c);
- OS cd(d);
- OS ce(e);
-
- OS oa;
- OS ob = nullopt;
- OS oc = string(n, 'c');
- string sod(n, 'd');
- OS od = sod;
- OS oe(in_place, n, 'e');
- OS of;
- of.emplace(n, 'f');
-
- OS ma(std::move(oa));
- OS mb(std::move(ob));
- OS mc(std::move(oc));
- OS md(std::move(od));
- OS me(std::move(oe));
- OS mf(std::move(of));
-
- OS aa1;
- OS ab1 = nullopt;
- OS ac1 = string(n, 'c');
- string sad1(n, 'd');
- OS ad1 = sad1;
- OS ae1(in_place, n, 'e');
- OS af1;
- af1.emplace(n, 'f');
-
- OS aa2;
- OS ab2 = nullopt;
- OS ac2 = string(n, 'c');
- string sad2(n, 'd');
- OS ad2 = sad2;
- OS ae2(in_place, n, 'e');
- OS af2;
- af2.emplace(n, 'f');
-
- aa1 = af2;
- ab1 = ae2;
- ac1 = ad2;
- ad1 = ac2;
- ae1 = ab2;
- af1 = aa2;
-
- OS aa3;
- OS ab3 = nullopt;
- OS ac3 = string(n, 'c');
- string sad3(n, 'd');
- OS ad3 = sad3;
- OS ae3(in_place, n, 'e');
- OS af3;
- af3.emplace(n, 'f');
-
- aa3 = nullopt;
- ab3 = nullopt;
- ac3 = nullopt;
- ad3 = nullopt;
- ae3 = nullopt;
- af3 = nullopt;
-
- OS aa4;
- OS ab4 = nullopt;
- OS ac4 = string(n, 'c');
- string sad4(n, 'd');
- OS ad4 = sad4;
- OS ae4(in_place, n, 'e');
- OS af4;
- af4.emplace(n, 'f');
-
- aa4 = OS(in_place, n, 'a');
- ab4 = OS(in_place, n, 'b');
- ac4 = OS(in_place, n, 'c');
- ad4 = OS(in_place, n, 'd');
- ae4 = OS(in_place, n, 'e');
- af4 = OS(in_place, n, 'f');
-
- OS aa5;
- OS ab5 = nullopt;
- OS ac5 = string(n, 'c');
- string sad5(n, 'd');
- OS ad5 = sad5;
- OS ae5(in_place, n, 'e');
- OS af5;
- af5.emplace(n, 'f');
-
- string saa5(n, 'a');
- string sab5(n, 'a');
- string sac5(n, 'a');
- string sad52(n, 'a');
- string sae5(n, 'a');
- string saf5(n, 'a');
-
- aa5 = saa5;
- ab5 = sab5;
- ac5 = sac5;
- ad5 = sad52;
- ae5 = sae5;
- af5 = saf5;
-
- OS aa6;
- OS ab6 = nullopt;
- OS ac6 = string(n, 'c');
- string sad6(n, 'd');
- OS ad6 = sad6;
- OS ae6(in_place, n, 'e');
- OS af6;
- af6.emplace(n, 'f');
-
- aa6 = string(n, 'a');
- ab6 = string(n, 'b');
- ac6 = string(n, 'c');
- ad6 = string(n, 'd');
- ae6 = string(n, 'e');
- af6 = string(n, 'f');
-
- OS aa7;
- OS ab7 = nullopt;
- OS ac7 = string(n, 'c');
- string sad7(n, 'd');
- OS ad7 = sad7;
- OS ae7(in_place, n, 'e');
- OS af7;
- af7.emplace(n, 'f');
-
- aa7.emplace(n, 'A');
- ab7.emplace(n, 'B');
- ac7.emplace(n, 'C');
- ad7.emplace(n, 'D');
- ae7.emplace(n, 'E');
- af7.emplace(n, 'F');
-}
-
-TEST(optionalTest, MoveAssignRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- Listenable b;
- a = std::move(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-}
-
-TEST(optionalTest, ValueType) {
- EXPECT_TRUE((std::is_same<optional<int>::value_type, int>::value));
- EXPECT_TRUE((std::is_same<optional<string>::value_type, string>::value));
- EXPECT_FALSE((std::is_same<optional<int>::value_type, nullopt_t>::value));
-}
-
-TEST(optionalTest, Hash) {
- std::hash<optional<int>> hash;
- std::set<size_t> hashcodes;
- hashcodes.insert(hash(nullopt));
- for (int i = 0; i < 100; ++i) {
- hashcodes.insert(hash(i));
- }
- EXPECT_GT(hashcodes.size(), 90);
-}
-
-struct MoveMeNoThrow {
- MoveMeNoThrow() : x(0) {}
- MoveMeNoThrow(const MoveMeNoThrow& other) : x(other.x) {
- LOG(FATAL) << "Should not be called.";
- }
- MoveMeNoThrow(MoveMeNoThrow&& other) noexcept : x(other.x) {}
- int x;
-};
-
-struct MoveMeThrow {
- MoveMeThrow() : x(0) {}
- MoveMeThrow(const MoveMeThrow& other) : x(other.x) {}
- MoveMeThrow(MoveMeThrow&& other) : x(other.x) {}
- int x;
-};
-
-TEST(optionalTest, NoExcept) {
- static_assert(
- std::is_nothrow_move_constructible<optional<MoveMeNoThrow>>::value, "");
- static_assert(
- !std::is_nothrow_move_constructible<optional<MoveMeThrow>>::value, "");
- std::vector<optional<MoveMeNoThrow>> v;
- v.reserve(10);
- for (int i = 0; i < 10; ++i) v.emplace_back();
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h
index e2927689d2..117b6a0bb8 100644
--- a/tensorflow/core/lib/io/block_builder.h
+++ b/tensorflow/core/lib/io/block_builder.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace table {
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
index e3649fd0c9..38fb0c5d86 100644
--- a/tensorflow/core/lib/io/path.h
+++ b/tensorflow/core/lib/io/path.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_LIB_IO_PATH_H_
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace io {
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index c24628be57..f93ebea771 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -109,9 +109,6 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
}
Status RecordReader::ReadRecord(uint64* offset, string* record) {
- static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
- static const size_t kFooterSize = sizeof(uint32);
-
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
int64 desired_pos = static_cast<int64>(*offset);
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index c05f9e1b36..11af1366b0 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -58,6 +58,14 @@ class RecordReaderOptions {
// Note: this class is not thread safe; external synchronization required.
class RecordReader {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
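The kHeaderSize/kFooterSize constants promoted into RecordReader's public interface above pin down the TFRecord framing: an 8-byte length, a 4-byte masked crc of the length, the payload, then a 4-byte masked crc of the payload. As a minimal sketch with local stand-in constants (not the tensorflow::io types themselves), the offset of the record that follows a payload of known length falls out directly from those two sizes:

#include <cstddef>
#include <cstdint>

// Local stand-ins for RecordReader::kHeaderSize / kFooterSize.
constexpr std::size_t kTFRecordHeaderSize = sizeof(uint64_t) + sizeof(uint32_t);  // 12
constexpr std::size_t kTFRecordFooterSize = sizeof(uint32_t);                     // 4

// Offset of the next record, given the current record's offset and the
// length of its payload.
inline uint64_t NextRecordOffset(uint64_t offset, uint64_t data_length) {
  return offset + kTFRecordHeaderSize + data_length + kTFRecordFooterSize;
}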
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 6e71d23e71..2c6db2487e 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() {
}
}
-static uint32 MaskedCrc(const char* data, size_t n) {
- return crc32c::Mask(crc32c::Value(data, n));
-}
-
Status RecordWriter::WriteRecord(StringPiece data) {
if (dest_ == nullptr) {
return Status(::tensorflow::error::FAILED_PRECONDITION,
@@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) {
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
- char header[sizeof(uint64) + sizeof(uint32)];
- core::EncodeFixed64(header + 0, data.size());
- core::EncodeFixed32(header + sizeof(uint64),
- MaskedCrc(header, sizeof(uint64)));
- char footer[sizeof(uint32)];
- core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
-
+ char header[kHeaderSize];
+ char footer[kFooterSize];
+ PopulateHeader(header, data.data(), data.size());
+ PopulateFooter(footer, data.data(), data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
TF_RETURN_IF_ERROR(dest_->Append(data));
return dest_->Append(StringPiece(footer, sizeof(footer)));
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 2f6afa5487..1212e1fafb 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -16,8 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
#define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
+#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
@@ -41,12 +43,20 @@ class RecordWriterOptions {
// Options specific to zlib compression.
#if !defined(IS_SLIM_BUILD)
- ZlibCompressionOptions zlib_options;
+ tensorflow::io::ZlibCompressionOptions zlib_options;
#endif // IS_SLIM_BUILD
};
class RecordWriter {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a writer that will append data to "*dest".
// "*dest" must be initially empty.
// "*dest" must remain live while this Writer is in use.
@@ -72,13 +82,35 @@ class RecordWriter {
// are invalid.
Status Close();
+ // Utility method to populate TFRecord headers. Populates record-header in
+ // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1].
+ inline static void PopulateHeader(char* header, const char* data, size_t n);
+
+ // Utility method to populate TFRecord footers. Populates record-footer in
+ // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1].
+ inline static void PopulateFooter(char* footer, const char* data, size_t n);
+
private:
WritableFile* dest_;
RecordWriterOptions options_;
+ inline static uint32 MaskedCrc(const char* data, size_t n) {
+ return crc32c::Mask(crc32c::Value(data, n));
+ }
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
};
+void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) {
+ core::EncodeFixed64(header + 0, n);
+ core::EncodeFixed32(header + sizeof(uint64),
+ MaskedCrc(header, sizeof(uint64)));
+}
+
+void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) {
+ core::EncodeFixed32(footer, MaskedCrc(data, n));
+}
+
} // namespace io
} // namespace tensorflow
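With the framing factored into the static PopulateHeader/PopulateFooter helpers above, WriteRecord reduces to filling two small buffers and appending three slices. The sketch below reproduces that framing in a self-contained way; the fixed-width encoders and FakeMaskedCrc are placeholders standing in for core::EncodeFixed*/crc32c, not the real implementations:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>

// Placeholder little-endian encoders and a fake checksum; the real code uses
// core::EncodeFixed64/EncodeFixed32 and crc32c::Mask(crc32c::Value(...)).
static void PutFixed64(char* dst, uint64_t v) { std::memcpy(dst, &v, sizeof(v)); }
static void PutFixed32(char* dst, uint32_t v) { std::memcpy(dst, &v, sizeof(v)); }
static uint32_t FakeMaskedCrc(const char* data, size_t n) {
  uint32_t h = 0;
  for (size_t i = 0; i < n; ++i) h = h * 131u + static_cast<unsigned char>(data[i]);
  return h;
}

// Frames `data` the way RecordWriter::WriteRecord does: header (length plus
// crc of the length bytes), payload, footer (crc of the payload).
std::string FrameRecord(const std::string& data) {
  char header[sizeof(uint64_t) + sizeof(uint32_t)];
  PutFixed64(header, data.size());
  PutFixed32(header + sizeof(uint64_t), FakeMaskedCrc(header, sizeof(uint64_t)));
  char footer[sizeof(uint32_t)];
  PutFixed32(footer, FakeMaskedCrc(data.data(), data.size()));

  std::string framed;
  framed.append(header, sizeof(header));
  framed.append(data);
  framed.append(footer, sizeof(footer));
  return framed;
}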
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
index da514bd21c..946d7188d3 100644
--- a/tensorflow/core/lib/io/recordio_test.cc
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -58,7 +58,7 @@ class StringDest : public WritableFile {
Status Close() override { return Status::OK(); }
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& slice) override {
+ Status Append(StringPiece slice) override {
contents_->append(slice.data(), slice.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
index 877ac40f1c..9cebbf40c6 100644
--- a/tensorflow/core/lib/io/table_test.cc
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -98,7 +98,7 @@ class StringSink : public WritableFile {
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
contents_.append(data.data(), data.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 84b47c171f..cba139e6ad 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() {
return Status::OK();
}
-Status ZlibOutputBuffer::Append(const StringPiece& data) {
+Status ZlibOutputBuffer::Append(StringPiece data) {
// If there is sufficient free space in z_stream_input_ to fit data we
// add it there and return.
// If there isn't enough space we deflate the existing contents of
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 3d86d89a99..ccad2fda44 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -62,7 +62,7 @@ class ZlibOutputBuffer : public WritableFile {
// to file when the buffer is full.
//
// To immediately write contents to file call `Flush()`.
- Status Append(const StringPiece& data) override;
+ Status Append(StringPiece data) override;
// Deflates any cached input and writes all output to file.
Status Flush() override;
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index c204d52cfe..9e4e1989dd 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace monitoring {
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index 756e5c2af8..bc4365e439 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace monitoring {
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
index bb5d20fb68..c876c5156a 100644
--- a/tensorflow/core/lib/png/png_io.h
+++ b/tensorflow/core/lib/png/png_io.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/png.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace png {
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 36d939e061..c536b5688e 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -232,6 +232,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
"Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
}
TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
+ if (*channel_count < 1) {
+ return errors::InvalidArgument(
+ "Bad number of channels for WAV: Expected at least 1, but got ",
+ *channel_count);
+ }
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
uint32 bytes_per_second;
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 01452b3e85..7c4184bff4 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -22,6 +22,10 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
REGISTER_OP("IsBoostedTreesEnsembleInitialized")
@@ -354,4 +358,125 @@ REGISTER_OP("BoostedTreesCenterBias")
return Status::OK();
});
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
+
+REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
+ .Attr("max_elements: int = 1099511627776") // 1 << 40
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("epsilon: float")
+ .Input("num_streams: int64")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesMakeQuantileSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("example_weights: float")
+ .Input("epsilon: float")
+ .Output("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle example_weights_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features), 1, &example_weights_shape));
+ for (int i = 0; i < num_features; ++i) {
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(example_weights_shape, 0),
+ &unused_dim));
+ // the columns are value, weight, min_rank, max_rank.
+ c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
+ }
+ // epsilon must be a scalar.
+ ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ // resource handle must be a scalar.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ // each summary must be rank 2.
+ for (int i = 1; i < num_features + 1; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
+ .Attr("generate_quantiles: bool = False")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("num_buckets: int64")
+ .SetShapeFn([](InferenceContext* c) {
+ // All the inputs are scalars.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("bucket_boundaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle unused_input;
+ // resource handle must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->Vector(c->UnknownDim()));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesBucketize")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("bucket_boundaries: num_features * float")
+ .Output("buckets: num_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ for (int i = 0; i < num_features; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(c->input(0), 0), &unused_dim));
+ }
+ // Bucketized result should have same dimension as input.
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
+ }
+ return Status::OK();
+ });
+
} // namespace tensorflow
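Every quantile-stream registration above uses the same shape-function recipe: check that the resource handle (and any other scalar inputs) has rank 0, then emit one output shape per feature. The sketch below restates that recipe for a purely illustrative op name, using only InferenceContext calls that already appear in this hunk:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

// "ExampleQuantileStreamOp" is hypothetical; it only illustrates the shared
// pattern: scalar resource handle in, one unknown-length vector per feature out.
REGISTER_OP("ExampleQuantileStreamOp")
    .Attr("num_features: int >= 0")
    .Input("quantile_stream_resource_handle: resource")
    .Output("per_feature_output: num_features * float")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      int num_features;
      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
      shape_inference::ShapeHandle unused_input;
      // The resource handle must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
      for (int i = 0; i < num_features; ++i) {
        c->set_output(i, c->Vector(c->UnknownDim()));
      }
      return Status::OK();
    });

}  // namespace tensorflow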
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 9e67662fa6..e59958749c 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11360,6 +11360,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -11469,6 +11492,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -11562,6 +11608,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -11631,6 +11703,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -13070,6 +13219,71 @@ op {
is_stateful: true
}
op {
+ name: "ConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Conj"
input_arg {
name: "input"
@@ -27127,6 +27341,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -29381,6 +29607,49 @@ op {
}
}
op {
+ name: "MapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "MapDefun"
input_arg {
name: "arguments"
@@ -34842,6 +35111,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -35682,6 +35974,42 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV2"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
@@ -35709,6 +36037,46 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV3"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
@@ -35747,6 +36115,57 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV4"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "valid_outputs"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "NonMaxSuppressionWithOverlaps"
input_arg {
name: "overlaps"
@@ -37037,6 +37456,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -37118,6 +37585,53 @@ op {
}
}
op {
+ name: "ParallelMapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ParameterizedTruncatedNormal"
input_arg {
name: "shape"
@@ -56665,6 +57179,125 @@ op {
}
}
op {
+ name: "SdcaOptimizer"
+ input_arg {
+ name: "sparse_example_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_values"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features_with_values"
+ }
+ input_arg {
+ name: "dense_features"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "example_labels"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "sparse_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_delta_sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ output_arg {
+ name: "out_delta_dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ attr {
+ name: "loss_type"
+ type: "string"
+ allowed_values {
+ list {
+ s: "logistic_loss"
+ s: "squared_loss"
+ s: "hinge_loss"
+ s: "smooth_hinge_loss"
+ s: "poisson_loss"
+ }
+ }
+ }
+ attr {
+ name: "adaptative"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "num_sparse_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_sparse_features_with_values"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_dense_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "l1"
+ type: "float"
+ }
+ attr {
+ name: "l2"
+ type: "float"
+ }
+ attr {
+ name: "num_loss_partitions"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "num_inner_iterations"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "SdcaShrinkL1"
input_arg {
name: "weights"
@@ -64381,6 +65014,71 @@ op {
is_stateful: true
}
op {
+ name: "SparseConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "SparseCross"
input_arg {
name: "indices"
@@ -69174,6 +69872,21 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
name: "StaticRegexReplace"
input_arg {
name: "input"
@@ -74889,9 +75602,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
type: DT_INT64
}
+ input_arg {
+ name: "shift"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index eed0bce174..ffab8ad661 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
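Both accumulator registrations above gain a reduction_type attr constrained to MEAN or SUM, so a kernel can read it once at construction and branch on it when the accumulated gradient is taken. The class below is a hedged illustration of that constructor-side read, not the actual ConditionalAccumulator kernel:

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Illustrative only: shows where the new attr would be consumed.
class ExampleAccumulatorOp : public OpKernel {
 public:
  explicit ExampleAccumulatorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("reduction_type", &reduction_type_));
    // The attr's allowed_values already restrict this to MEAN or SUM; the
    // check here just documents the contract.
    OP_REQUIRES(ctx, reduction_type_ == "MEAN" || reduction_type_ == "SUM",
                errors::InvalidArgument("Unsupported reduction_type: ",
                                        reduction_type_));
  }

  void Compute(OpKernelContext* ctx) override {
    // Accumulation elided; reduction_type_ would select mean vs. sum.
  }

 private:
  string reduction_type_;
};

}  // namespace tensorflow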
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index f03639e833..4d3f272c1b 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -198,6 +198,7 @@ REGISTER_OP("MapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ParallelMapDataset")
@@ -209,6 +210,7 @@ REGISTER_OP("ParallelMapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
@@ -325,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ParallelInterleaveDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("num_parallel_calls: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -381,14 +396,20 @@ REGISTER_OP("FilterByLastComponentDataset")
REGISTER_OP("WindowDataset")
.Input("input_dataset: variant")
- .Input("window_size: int64")
+ .Input("size: int64")
+ .Input("shift: int64")
+ .Input("stride: int64")
+ .Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
- // batch_size should be a scalar.
+ // size, shift, stride, and drop_remainder should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
return shape_inference::ScalarShape(c);
});
@@ -858,6 +879,13 @@ REGISTER_OP("IteratorGetNextAsOptional")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ModelDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
.Output("output: output_types")
@@ -866,7 +894,7 @@ REGISTER_OP("MapDefun")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<TensorShape> output_shapes;
+ std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
@@ -876,6 +904,10 @@ REGISTER_OP("MapDefun")
int64 dim_zero = -1;
for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ if (c->Rank(c->input(i)) == 0) {
+ return errors::InvalidArgument(
+ "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ }
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
if (dim_zero == -1) {
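The WindowDataset change in the hunks above splits the old window_size input into size, shift, stride, and drop_remainder, each checked to be a scalar. The helper below is a hedged sketch of how those four values are commonly interpreted (it mirrors tf.data windowing semantics as generally documented, not code from this diff), purely to make the parameters concrete:

#include <cstdint>
#include <vector>

// Enumerates the element indices of each window over n input elements.
// size: elements per window; shift: distance between window starts;
// stride: step between elements inside a window; drop_remainder: skip
// windows that cannot be filled to `size` elements.
std::vector<std::vector<int64_t>> WindowIndices(int64_t n, int64_t size,
                                                int64_t shift, int64_t stride,
                                                bool drop_remainder) {
  std::vector<std::vector<int64_t>> windows;
  for (int64_t start = 0; start < n; start += shift) {
    std::vector<int64_t> w;
    for (int64_t k = 0; k < size && start + k * stride < n; ++k) {
      w.push_back(start + k * stride);
    }
    if (drop_remainder && static_cast<int64_t>(w.size()) < size) continue;
    if (!w.empty()) windows.push_back(w);
  }
  return windows;
}
// Example: n=10, size=3, shift=2, stride=1 yields {0,1,2}, {2,3,4}, {4,5,6},
// {6,7,8}, and finally {8,9} unless drop_remainder is set.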
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 11ca0bd259..5427275284 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression")
});
REGISTER_OP("NonMaxSuppressionV2")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
// Get inputs and validate ranks.
ShapeHandle boxes;
@@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2")
});
REGISTER_OP("NonMaxSuppressionV3")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn(NMSShapeFn);
REGISTER_OP("NonMaxSuppressionV4")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
.Output("valid_outputs: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.Attr("pad_to_max_output_size: bool = false")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(NMSShapeFn(c));
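Parameterizing boxes and scores on T above means the NonMaxSuppression kernels are registered once per allowed dtype instead of only for float. A hedged sketch of that registration pattern follows; the kernel class name is a placeholder for illustration, not the kernel actually touched by this change, and the fragment would live in the kernel .cc inside namespace tensorflow:

// Per-dtype CPU registration implied by the new "T: {half, float}" attr.
// NonMaxSuppressionV2Op<T> is a placeholder class name.
#define REGISTER_NMS_V2(T)                              \
  REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")   \
                              .Device(DEVICE_CPU)       \
                              .TypeConstraint<T>("T"),  \
                          NonMaxSuppressionV2Op<T>)

REGISTER_NMS_V2(Eigen::half);
REGISTER_NMS_V2(float);
#undef REGISTER_NMS_V2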
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index c0376b5721..4ece1c8953 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4272,6 +4272,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -4381,6 +4404,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -4474,6 +4520,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -4543,6 +4615,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -5592,6 +5741,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -13149,6 +13311,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -14542,6 +14716,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "MapDefun"
@@ -16540,6 +16721,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -17078,11 +17282,11 @@ op {
name: "NonMaxSuppressionV2"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17096,16 +17300,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17123,16 +17340,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17155,6 +17385,19 @@ op {
type: DT_INT32
}
attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
name: "pad_to_max_output_size"
type: "bool"
default_value {
@@ -18192,6 +18435,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -18230,6 +18521,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "ParameterizedTruncatedNormal"
@@ -26977,6 +27275,7 @@ op {
s: "squared_loss"
s: "hinge_loss"
s: "smooth_hinge_loss"
+ s: "poisson_loss"
}
}
}
@@ -29609,6 +29908,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -32107,6 +32419,21 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
name: "StaticRegexReplace"
input_arg {
name: "input"
@@ -35872,9 +36199,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
type: DT_INT64
}
+ input_arg {
+ name: "shift"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 79ca96d249..eff453241d 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -343,10 +343,11 @@ REGISTER_OP("DecodeCSV")
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
ShapeHandle v;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v));
- if (c->Value(c->Dim(v, 0)) > 1) {
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
return errors::InvalidArgument(
- "Shape of a default must be a length-0 or length-1 vector");
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
}
}
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index c65e66d1a8..ba594e400c 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -52,9 +52,12 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) {
INFER_OK(op, "[1,2,?,4];?;?", "in0;in0");
INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0");
+ // Scalar defaults are ok
+ INFER_OK(op, "?;?;[]", "in0;in0");
+
// Check errors in the record_defaults inputs.
- INFER_ERROR("must be rank 1", op, "?;?;[]");
- INFER_ERROR("must be rank 1", op, "?;[];?");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?");
INFER_ERROR("Shape of a default must be", op, "?;?;[2]");
INFER_ERROR("Shape of a default must be", op, "?;[2];?");
}
diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc
index 4025070adb..fdf53a55dd 100644
--- a/tensorflow/core/ops/sdca_ops.cc
+++ b/tensorflow/core/ops/sdca_ops.cc
@@ -41,7 +41,7 @@ static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) {
REGISTER_OP("SdcaOptimizer")
.Attr(
"loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
- "'smooth_hinge_loss'}")
+ "'smooth_hinge_loss', 'poisson_loss'}")
.Attr("adaptative : bool=false")
.Attr("num_sparse_features: int >= 0")
.Attr("num_sparse_features_with_values: int >= 0")
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 7aa1e71809..ef8b15dc8a 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch")
return Status::OK();
});
+REGISTER_OP("StaticRegexFullMatch")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc
index e597a490d6..d7a13a3528 100644
--- a/tensorflow/core/platform/abi.cc
+++ b/tensorflow/core/platform/abi.cc
@@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name,
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name) {
+string MaybeAbiDemangle(const char* name) {
#if defined(_MSC_VER)
std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc,
std::free,
static_cast<unsigned short>(0))};
- return std::string(demangled.get() != nullptr ? demangled.get() : name);
+ return string(demangled.get() != nullptr ? demangled.get() : name);
#else
int status = 0;
std::unique_ptr<char, void (*)(void*)> res{
diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h
index 591e83b0c4..d1498a6a64 100644
--- a/tensorflow/core/platform/abi.h
+++ b/tensorflow/core/platform/abi.h
@@ -17,11 +17,12 @@ limitations under the License.
#define TENSORFLOW_CORE_PLATFORM_ABI_H_
#include <string>
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name);
+string MaybeAbiDemangle(const char* name);
} // namespace port
} // namespace tensorflow
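MaybeAbiDemangle now returns tensorflow::string instead of std::string. A small illustrative use, turning a mangled typeid name into something readable; the struct name is hypothetical:

#include <typeinfo>

#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/types.h"

struct MyHypotheticalType {};

// Demangle the compiler's type name so it can go into an error message.
tensorflow::string DescribeType() {
  return tensorflow::port::MaybeAbiDemangle(typeid(MyHypotheticalType).name());
}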
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index a1be4aacce..5e1eabee5b 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size,
.StopCapture()
.OneLiteral(": ")
.GetResult(&value, &name)) {
- string str_value = std::string(value);
+ string str_value(value);
str_util::StripTrailingWhitespace(&str_value);
- that->response_headers_[std::string(name)] = str_value;
+ that->response_headers_[string(name)] = str_value;
}
return size * nmemb;
}
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 9d33787bd5..83228fab6f 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -179,13 +179,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
fname);
}
- *bucket = std::string(bucketp);
+ *bucket = string(bucketp);
if (bucket->empty() || *bucket == ".") {
return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
fname);
}
str_util::ConsumePrefix(&objectp, "/");
- *object = std::string(objectp);
+ *object = string(objectp);
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("GCS path doesn't contain an object name: ",
fname);
@@ -224,7 +224,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) {
for (const string& path : paths) {
StringPiece subpath = io::Dirname(path);
while (!subpath.empty()) {
- result.emplace(std::string(subpath));
+ result.emplace(string(subpath));
subpath = io::Dirname(subpath);
}
}
@@ -371,7 +371,7 @@ class GcsWritableFile : public WritableFile {
~GcsWritableFile() override { Close().IgnoreError(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
TF_RETURN_IF_ERROR(CheckWritable());
sync_needed_ = true;
outfile_ << data;
@@ -723,7 +723,7 @@ GcsFileSystem::GcsFileSystem() {
if (!header_name.empty() && !header_value.empty()) {
additional_header_.reset(new std::pair<const string, const string>(
- std::string(header_name), std::string(header_value)));
+ string(header_name), string(header_value)));
VLOG(1) << "GCS additional header ENABLED. "
<< "Name: " << additional_header_->first << ", "
@@ -1229,7 +1229,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern,
// Find the fixed prefix by looking for the first wildcard.
const string& fixed_prefix =
pattern.substr(0, pattern.find_first_of("*?[\\"));
- const string& dir = std::string(io::Dirname(fixed_prefix));
+ const string dir(io::Dirname(fixed_prefix));
if (dir.empty()) {
return errors::InvalidArgument(
"A GCS pattern doesn't have a bucket name: ", pattern);
@@ -1326,7 +1326,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
" doesn't match the prefix ", object_prefix));
}
if (!relative_path.empty() || include_self_directory_marker) {
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
}
if (++retrieved_results >= max_results) {
return Status::OK();
@@ -1354,7 +1354,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
"Unexpected response: the returned folder name ", prefix_str,
" doesn't match the prefix ", object_prefix);
}
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
if (++retrieved_results >= max_results) {
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index ee6ba7b041..9b85cae9b9 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson(
// Send the request to the Google OAuth 2.0 server to get the token.
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
@@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson(
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 4ffa72288b..1cd0641cd3 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
grant_type);
- int last_dot = std::string(assertion).find_last_of(".");
- string header_dot_claim = std::string(assertion.substr(0, last_dot));
- string signature_encoded = std::string(assertion.substr(last_dot + 1));
+ int last_dot = assertion.rfind('.');
+ string header_dot_claim(assertion.substr(0, last_dot));
+ string signature_encoded(assertion.substr(last_dot + 1));
// Check that 'signature' signs 'header_dot_claim'.
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 92aa72be89..941ab7ad65 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -177,7 +177,7 @@ class RetryingWritableFile : public WritableFile {
Close().IgnoreError();
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return RetryingUtils::CallWithRetries(
[this, &data]() { return base_file_->Append(data); },
initial_delay_microseconds_);
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index ec2c470db7..5910fef1d2 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -72,7 +72,7 @@ class MockRandomAccessFile : public RandomAccessFile {
class MockWritableFile : public WritableFile {
public:
explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return calls_.ConsumeNextCall("Append");
}
Status Close() override { return calls_.ConsumeNextCall("Close"); }
diff --git a/tensorflow/core/platform/cord.h b/tensorflow/core/platform/cord.h
new file mode 100644
index 0000000000..7c5c6655be
--- /dev/null
+++ b/tensorflow/core/platform/cord.h
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_CORD_H_
+
+// Include appropriate platform-dependent implementations
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/cord.h"
+#else
+#include "tensorflow/core/platform/default/cord.h"
+#endif
+
+#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_
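The default stub only forward-declares a Cord type and aliases it into absl, so signatures that mention absl::Cord by reference compile on non-Google builds, while code that actually constructs or reads a Cord stays behind PLATFORM_GOOGLE. A small illustration of what the stub permits; the struct is hypothetical:

#include "tensorflow/core/platform/cord.h"

struct CordSinkSketch {
  // OK with the stub: a reference to an incomplete type needs no definition.
  void Write(const absl::Cord& cord);
};
// Defining Write to read or build a Cord would require the real
// implementation, which is why such code is guarded elsewhere.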
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 0411a8c4f9..bb841aeab7 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -625,7 +625,9 @@ def tf_additional_lib_deps():
"""Additional dependencies needed to build TF libraries."""
return [
"@com_google_absl//absl/base:base",
+ "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/types:span",
+ "@com_google_absl//absl/types:optional",
] + if_static(
["@nsync//:nsync_cpp"],
["@nsync//:nsync_headers"],
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 3a012c23fd..37475feebe 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -3,64 +3,64 @@
# be separate to avoid cyclic references.
def tf_cuda_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_sycl_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_additional_plugin_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): [
- str(Label("//tensorflow/compiler/jit"))
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): [
+ str(Label("//tensorflow/compiler/jit")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_xla_deps_py():
- return []
+ return []
def tf_additional_grpc_deps_py():
- return []
+ return []
def tf_additional_license_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
+ "//conditions:default": [],
+ })
def tf_additional_verbs_deps():
- return select({
- str(Label("//tensorflow:with_verbs_support")): [
- str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
- str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_verbs_support")): [
+ str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
+ str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_deps():
- return select({
- str(Label("//tensorflow:with_mpi_support")): [
- str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_mpi_support")): [
+ str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_deps():
- return select({
- str(Label("//tensorflow:with_gdr_support")): [
- str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_gdr_support")): [
+ str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
-def if_static(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:framework_shared_object")): otherwise,
- "//conditions:default": extra_deps,
- })
+def if_static(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:framework_shared_object")): otherwise,
+ "//conditions:default": extra_deps,
+ })
-def if_dynamic_kernels(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
- "//conditions:default": otherwise,
- })
+def if_dynamic_kernels(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+ "//conditions:default": otherwise,
+ })
diff --git a/tensorflow/core/lib/gtl/optional.cc b/tensorflow/core/platform/default/cord.h
index 8dea073788..1ab682182c 100644
--- a/tensorflow/core/lib/gtl/optional.cc
+++ b/tensorflow/core/platform/default/cord.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/lib/gtl/optional.h"
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
-namespace tensorflow {
-namespace gtl {
+class Cord;
+namespace absl {
+using ::Cord;
+} // namespace absl
-nullopt_t::init_t nullopt_t::init;
-extern const nullopt_t nullopt{nullopt_t::init};
-
-} // namespace gtl
-} // namespace tensorflow
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index ccddf1eafc..0389149469 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,6 +321,11 @@ class DeviceTracerImpl : public DeviceTracer,
return nullptr;
}
+ bool IsEnabled(bool is_expensive) const override {
+ // We don't do anything with 'Activities' so we are never 'enabled'.
+ return false;
+ }
+
protected:
// This callback is used exclusively by CUPTIManager.
friend class CUPTIManager;
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 305a9a682f..2e32abdffb 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -345,7 +346,13 @@ TEST_F(DefaultEnvTest, LocalTempFilename) {
// Write something to the temporary file.
std::unique_ptr<WritableFile> file_to_write;
TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write));
+#if defined(PLATFORM_GOOGLE)
+ TF_CHECK_OK(file_to_write->Append("Nu"));
+ TF_CHECK_OK(file_to_write->Append(absl::Cord("ll")));
+#else
+ // TODO(ebrevdo): Remove this version.
TF_CHECK_OK(file_to_write->Append("Null"));
+#endif
TF_CHECK_OK(file_to_write->Close());
TF_CHECK_OK(env->FileExists(filename));
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 077b1d79cf..30059dc02e 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/file_statistics.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/platform.h"
@@ -252,7 +253,12 @@ class WritableFile {
virtual ~WritableFile();
/// \brief Append 'data' to the file.
- virtual Status Append(const StringPiece& data) = 0;
+ virtual Status Append(StringPiece data) = 0;
+
+ /// \brief Append 'cord' to the file.
+ virtual Status Append(const absl::Cord& cord) {
+ return errors::Unimplemented("Append(absl::Cord) is not implemented");
+ }
/// \brief Close the file.
///
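Because the base class supplies an Unimplemented default for the Cord overload, existing file systems keep compiling unchanged; a subclass that only understands contiguous buffers could forward the overload to its StringPiece path. A sketch, assuming a real absl::Cord (and its Chunks() iteration API) is available, i.e. the PLATFORM_GOOGLE configuration of this change; the class and its buffer are hypothetical:

#include "absl/strings/cord.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/file_system.h"

namespace tensorflow {

class BufferBackedFileSketch : public WritableFile {
 public:
  Status Append(StringPiece data) override {
    buffer_.append(data.data(), data.size());
    return Status::OK();
  }

  Status Append(const absl::Cord& cord) override {
    // Flatten the cord chunk by chunk into the contiguous-buffer path.
    for (absl::string_view chunk : cord.Chunks()) {
      TF_RETURN_IF_ERROR(Append(StringPiece(chunk.data(), chunk.size())));
    }
    return Status::OK();
  }

  Status Close() override { return Status::OK(); }
  Status Flush() override { return Status::OK(); }
  Status Sync() override { return Status::OK(); }

 private:
  string buffer_;
};

}  // namespace tensorflow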
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 8cdb08f51b..eb35531e9f 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -282,7 +282,7 @@ class HDFSWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (hdfs_->hdfsWrite(fs_, file_, data.data(),
static_cast<tSize>(data.size())) == -1) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index 47bfa020ce..c7afab9583 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -91,7 +91,7 @@ class PosixWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
size_t r = fwrite(data.data(), 1, data.size(), file_);
if (r != data.size()) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index ce0f6cd741..e0b8e37745 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -211,7 +211,7 @@ class S3WritableFile : public WritableFile {
std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
std::ios_base::out)) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (!outfile_) {
return errors::FailedPrecondition(
"The internal temporary file is not writable.");
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index e5851f1dfe..9974bbbb4e 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -155,6 +155,10 @@ class TraceCollector {
StringPiece name_part1, StringPiece name_part2,
bool is_expensive) const = 0;
+ // Returns true if activity-handle tracking is enabled for an op of the
+ // given expensiveness.
+ virtual bool IsEnabled(bool is_expensive) const = 0;
+
protected:
static string ConcatenateNames(StringPiece first, StringPiece second);
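The point of the new IsEnabled hook is that call sites can skip building name strings and activity handles entirely when no tracing backend is active. A hypothetical call site; the function, the variable names, and the CreateActivityHandle method name are assumptions rather than part of this hunk:

#include <memory>

#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"

// Only pay for handle creation when a collector exists and is enabled for
// ops of this cost class.
void MaybeTraceOp(const tensorflow::string& node_name, bool is_expensive) {
  using tensorflow::tracing::TraceCollector;
  const TraceCollector* collector = tensorflow::tracing::GetTraceCollector();
  std::unique_ptr<TraceCollector::Handle> handle;
  if (collector != nullptr && collector->IsEnabled(is_expensive)) {
    handle =
        collector->CreateActivityHandle("Compute", node_name, is_expensive);
  }
  // ... run the op while 'handle' keeps the activity open ...
}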
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 9079a5ccaa..6cf79634d7 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -150,7 +150,7 @@ class WindowsWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
DWORD bytes_written = 0;
DWORD data_size = static_cast<DWORD>(data.size());
BOOL write_result =
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index c68504a272..85cd02350a 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -390,9 +390,12 @@ message ConfigProto {
message Experimental {
// Task name for group resolution.
string collective_group_leader = 1;
- // Whether the client will format templated errors. For example, the string:
- // "The node was defined on ^^node:Foo:${file}:${line}^^".
- bool client_handles_error_formatting = 2;
+
+ // We removed the flag client_handles_error_formatting. Marking the tag
+ // number as reserved.
+ // TODO(shikharagarwal): Should we just remove this tag so that it can be
+ // used in the future for another purpose?
+ reserved 2;
// Which executor to use, the default executor will be used
// if it is an empty string or "DEFAULT"
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4129c93af5..b043a69431 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 10
+#define TF_MINOR_VERSION 11
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc1"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h
index 973e315f09..24002e72a0 100644
--- a/tensorflow/core/util/ctc/ctc_beam_entry.h
+++ b/tensorflow/core/util/ctc/ctc_beam_entry.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h
index 1a622babe1..1e45a8abd3 100644
--- a/tensorflow/core/util/ctc/ctc_beam_scorer.h
+++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
// Collection of scoring classes that can be extended and provided to the
// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index 5e2aeb7830..6fbb1ed0da 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index 3be36822e5..b55d7d77ac 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h
index 36be9e92ef..054412d388 100644
--- a/tensorflow/core/util/ctc/ctc_loss_util.h
+++ b/tensorflow/core/util/ctc/ctc_loss_util.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 6474319370..680211edff 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <string>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -56,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/env_var.h"
#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
@@ -102,6 +104,8 @@ typedef enum {
Dim3d_I = 1
} MklDnnDims3D;
+static const int kSmallBatchSize = 32;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -2000,7 +2004,9 @@ const mkldnn::memory::dims NONE_DIMS = {};
template <typename T>
class MklPrimitiveFactory {
public:
- MklPrimitiveFactory() {}
+ MklPrimitiveFactory() {
+ }
+
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
@@ -2023,6 +2029,22 @@ class MklPrimitiveFactory {
map[key] = op;
}
+ /// Function to decide whether the hardware has AVX512 or AVX2.
+ /// For legacy devices (without AVX512 and AVX2),
+ /// MKL-DNN GEMM will be used.
+ static inline bool IsLegacyPlatform() {
+ return (!port::TestCPUFeature(port::CPUFeature::AVX512F)
+ && !port::TestCPUFeature(port::CPUFeature::AVX2));
+ }
+
+ /// Function to check whether primitive memory optimization is enabled
+ static inline bool IsPrimitiveMemOptEnabled() {
+ bool is_primitive_mem_opt_enabled = true;
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
+ return is_primitive_mem_opt_enabled;
+ }
+
private:
static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<string, MklPrimitive*> map_;
@@ -2060,7 +2082,7 @@ class FactoryKeyCreator {
const char delimiter = 'x';
const int kMaxKeyLength = 256;
void Append(StringPiece s) {
- key_.append(s.ToString());
+ key_.append(string(s));
key_.append(1, delimiter);
}
};
@@ -2099,7 +2121,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -2141,7 +2163,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- private:
+ private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
@@ -2186,6 +2208,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
return *reorder_prim->GetPrimitive();
}
+// Utility function to determine whether a convolution is 1x1 with stride != 1,
+// for the purpose of temporarily disabling primitive reuse.
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
+                                memory::dims strides) {
+ if (filter_dims.size() != 4 || strides.size() != 2) return false;
+
+ return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
+ ((strides[0] != 1) || (strides[1] != 1)));
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow
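For reference, how the new IsConv1x1StrideNot1 predicate behaves on a few example shapes. This snippet is illustrative only, assumes it sits in a translation unit that already includes mkl_util.h under INTEL_MKL, and the OIHW reading of the filter layout is inferred from the indices the predicate checks:

// Example values only; positions 2 and 3 of filter_dims are the spatial dims.
void IsConv1x1StrideNot1Examples() {
  mkldnn::memory::dims f1x1 = {64, 3, 1, 1};  // 1x1 filter
  mkldnn::memory::dims f3x3 = {64, 3, 3, 3};  // 3x3 filter
  mkldnn::memory::dims stride1 = {1, 1};
  mkldnn::memory::dims stride2 = {2, 2};

  bool a = tensorflow::IsConv1x1StrideNot1(f1x1, stride2);  // true: 1x1, stride != 1
  bool b = tensorflow::IsConv1x1StrideNot1(f1x1, stride1);  // false: stride is 1
  bool c = tensorflow::IsConv1x1StrideNot1(f3x3, stride2);  // false: not a 1x1 filter
  (void)a; (void)b; (void)c;
}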
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
index 204b933051..546b0a833c 100644
--- a/tensorflow/core/util/sparse/group_iterator.cc
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -21,8 +21,8 @@ namespace sparse {
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
++next_loc_;
- int64 N = iter_->ix_.dim_size(0);
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
+ const int64 N = ix_t.dimension(0);
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
++next_loc_;
}
@@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
std::vector<int64> Group::group() const {
std::vector<int64> g;
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
for (const int d : iter_->group_dims_) {
g.push_back(ix_t(loc_, d));
}
@@ -62,8 +62,8 @@ std::vector<int64> Group::group() const {
}
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
- return TTypes<int64>::UnalignedConstMatrix(
- &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+ return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)),
+ next_loc_ - loc_, iter_->dims_);
}
} // namespace sparse
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index 3fa8cb6116..14610c61d9 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -79,6 +79,7 @@ class GroupIterable {
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
: ix_(ix),
+ ix_matrix_(ix_.matrix<int64>()),
vals_(vals),
dims_(dims),
group_dims_(group_dims.begin(), group_dims.end()) {}
@@ -127,7 +128,8 @@ class GroupIterable {
private:
friend class Group;
- Tensor ix_;
+ const Tensor ix_;
+ const TTypes<int64>::ConstMatrix ix_matrix_;
Tensor vals_;
const int dims_;
const gtl::InlinedVector<int64, 8> group_dims_;
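The iterator now builds the int64 index matrix view once, in the constructor, and every step reuses the cached ix_matrix_ instead of re-creating the Eigen map via ix_.matrix<int64>(). The same pattern in isolation; the function and its purpose are illustrative, not code from this change:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

// Count how many distinct leading-dimension groups appear in a rank-2 int64
// indices tensor, touching the Eigen map through one cached view.
int64 CountGroupsSketch(const Tensor& ix) {
  const TTypes<int64>::ConstMatrix m = ix.matrix<int64>();  // built once
  const int64 n = m.dimension(0);
  int64 groups = n > 0 ? 1 : 0;
  for (int64 i = 1; i < n; ++i) {
    if (m(i, 0) != m(i - 1, 0)) ++groups;  // reuse 'm' on every row
  }
  return groups;
}

}  // namespace tensorflow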
diff --git a/tensorflow/core/util/status_util.h b/tensorflow/core/util/status_util.h
deleted file mode 100644
index ea92f61dce..0000000000
--- a/tensorflow/core/util/status_util.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
-#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
-
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace tensorflow {
-
-// Creates a tag to be used in an exception error message. This can be parsed by
-// the Python layer and replaced with information about the node.
-//
-// For example, error_format_tag(node, "${file}") returns
-// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as
-// e.g. "file/where/node/was/created.py".
-inline string error_format_tag(const Node& node, const string& format) {
- return strings::StrCat("^^node:", node.name(), ":", format, "^^");
-}
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
diff --git a/tensorflow/core/util/status_util_test.cc b/tensorflow/core/util/status_util_test.cc
deleted file mode 100644
index 1f06004db2..0000000000
--- a/tensorflow/core/util/status_util_test.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/util/status_util.h"
-
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace {
-
-TEST(TestStatusUtil, ErrorFormatTagForNode) {
- Graph graph(OpRegistry::Global());
- Node* node;
- TF_CHECK_OK(NodeBuilder("Foo", "NoOp").Finalize(&graph, &node));
- EXPECT_EQ(error_format_tag(*node, "${line}"), "^^node:Foo:${line}^^");
- EXPECT_EQ(error_format_tag(*node, "${file}:${line}"),
- "^^node:Foo:${file}:${line}^^");
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h
index 6539d565e2..7b101971a8 100644
--- a/tensorflow/core/util/tensor_bundle/naming.h
+++ b/tensorflow/core/util/tensor_bundle/naming.h
@@ -35,6 +35,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {